import numpy as np
import pandas as pd
import re
from collections import Counter, defaultdict
from scipy import stats
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
import seaborn as sns
import json
import joblib
from itertools import combinations
import RNA
import os
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from sklearn.cluster import DBSCAN
import Levenshtein
import tqdm

class StemPatternMiner:
    def __init__(self, output_dir=".", max_stem_length=5, min_stem_length=1, threads=1):
        """Configure the miner and initialise its result containers.

        Args:
            output_dir: directory stored for later use (presumably by the result /
                plot writers, which are not shown here -- confirm).
            max_stem_length: upper stem-length bound forwarded to the multiprocess
                stem extractor.
            min_stem_length: stems with fewer stacked pairs are discarded.
            threads: number of worker threads / processes used by the parallel paths.
        """
        # Run configuration (assigned in the same order as before).
        self.output_dir, self.max_stem_length, self.min_stem_length, self.threads = (
            output_dir, max_stem_length, min_stem_length, threads)
        # Mutable result state populated while mining.
        self.pattern_stats = {}
        self.significant_patterns = []
        self.stem_characteristics = defaultdict(list)
    
    def cluster_similar_motifs_fast(self, stems, similarity_threshold=0.7):
        """Cluster stems cheaply by bucketing on length and paired-base ratio.

        Stems are grouped by (length rounded down to a multiple of 3, paired-bracket
        ratio rounded to one decimal). Inside each group every stem is relabelled
        with the group's most common ``stem_structure`` (ties go to the pattern seen
        first). Unlike the DBSCAN path, ``similarity_threshold`` does not influence
        the grouping; it is only reported.

        Args:
            stems: stem dicts carrying a ``'stem_structure'`` dot-bracket string.
            similarity_threshold: printed for consistency with the other clustering
                entry points.

        Returns:
            A new list of shallow-copied stem dicts (grouped by bucket). Every copy has
            the representative in ``'stem_structure'``, its own previous structure in
            ``'original_pattern'`` and ``'cluster_id'`` = ``hash(representative)``.
            The input dicts are not modified. An empty input returns ``[]``.
        """
        print(f"Fast clustering {len(stems)} stems with threshold {similarity_threshold}...")

        if not stems:
            # Nothing to group; also avoids dividing by zero in the reduction statistic.
            return []

        stem_groups = defaultdict(list)
        for stem in stems:
            pattern = stem['stem_structure']
            length = len(pattern)
            paired_ratio = (pattern.count('(') + pattern.count(')')) / length if length > 0 else 0

            length_group = (length // 3) * 3      # 3-nt length bins
            ratio_group = round(paired_ratio, 1)  # 0.1-wide paired-ratio bins
            stem_groups[(length_group, ratio_group)].append(stem)

        clustered_stems = []
        for group_stems in stem_groups.values():
            pattern_counts = Counter(stem['stem_structure'] for stem in group_stems)
            representative_pattern = pattern_counts.most_common(1)[0][0]
            cluster_id = hash(representative_pattern)

            # Annotate every stem, including single-member groups, so all returned
            # dicts share the same schema.
            for stem in group_stems:
                new_stem = stem.copy()
                new_stem['stem_structure'] = representative_pattern
                new_stem['original_pattern'] = stem['stem_structure']
                new_stem['cluster_id'] = cluster_id
                clustered_stems.append(new_stem)

        original_count = len({stem['stem_structure'] for stem in stems})
        clustered_count = len({stem['stem_structure'] for stem in clustered_stems})
        reduction_ratio = (original_count - clustered_count) / original_count * 100

        print(f"Fast clustering reduced {original_count} to {clustered_count} patterns ({reduction_ratio:.1f}% reduction)")

        return clustered_stems


    def cluster_similar_motifs_parallel(self, stems, similarity_threshold=0.7):
        """Cluster stems by structural similarity with DBSCAN (parallel helpers).

        Inputs with more than 50,000 stems go straight to
        ``cluster_similar_motifs_fast``. Otherwise the distinct ``stem_structure``
        strings are scored pairwise, clustered by DBSCAN on the distance
        ``1 - similarity`` (``eps = 1 - similarity_threshold``, ``min_samples=1``),
        and every stem is relabelled with its cluster's most frequent pattern.

        The input list itself is returned when there is at most one distinct pattern
        or when each pattern forms its own cluster. Any exception raised while
        clustering is reported and the fast bucketing method is used instead.
        """
        print(f"Clustering {len(stems)} stems with similarity threshold {similarity_threshold}...")
        print(f"Using {self.threads} thread(s) for clustering...")

        # Pairwise scoring is quadratic in the number of patterns; big inputs use
        # the cheap bucketing heuristic.
        if len(stems) > 50000:
            print("Large dataset detected, using fast clustering method...")
            return self.cluster_similar_motifs_fast(stems, similarity_threshold)

        try:
            distinct = list(set(stem['stem_structure'] for stem in stems))
            if len(distinct) <= 1:
                return stems  # nothing to merge

            similarity = self.calculate_structure_similarity_parallel(distinct)
            model = DBSCAN(eps=1 - similarity_threshold, min_samples=1, metric='precomputed')
            assignment = model.fit_predict(1 - similarity)

            if len(set(assignment)) == len(distinct):
                print("No clustering occurred, all patterns are unique")
                return stems

            representatives = self.find_cluster_representatives_parallel(assignment, distinct, stems)
            relabelled = self.update_stems_with_clusters_parallel(
                stems, representatives, assignment, distinct
            )

            before, after = len(distinct), len(representatives)
            print(f"Clustering reduced {before} unique patterns to {after} clusters ({(before - after) / before * 100:.1f}% reduction)")
            return relabelled

        except Exception as e:
            print(f"Clustering failed with error: {e}")
            print("Falling back to fast clustering method...")
            return self.cluster_similar_motifs_fast(stems, similarity_threshold)
    
    def calculate_structure_similarity_parallel(self, patterns):
        """Build the symmetric pairwise similarity matrix for *patterns*.

        Each unordered pair (i, j) with i < j is scored once with
        ``structure_similarity`` and mirrored to (j, i); the diagonal is 1.0.
        With ``self.threads > 1`` and more than 1000 pairs, scoring is submitted to
        a thread pool in chunks. A pair whose scoring raises there is reported and
        scored 0. In the sequential path exceptions propagate.

        Pairs are generated lazily with ``combinations``/``islice``, so memory is
        bounded by the n x n result rather than an additional O(n^2) task list.

        Returns:
            numpy array of shape (len(patterns), len(patterns)).
        """
        from itertools import islice

        n = len(patterns)
        similarity_matrix = np.eye(n)
        total_pairs = n * (n - 1) // 2

        # Plain printed progress: tqdm is deliberately avoided here (it caused
        # tkinter-related problems according to the original notes).
        print(f"Calculating similarities for {total_pairs} pattern pairs...")

        pairs = combinations(range(n), 2)

        if self.threads > 1 and total_pairs > 1000:
            chunk_size = max(1000, total_pairs // (self.threads * 4))
            completed = 0

            with ThreadPoolExecutor(max_workers=self.threads) as executor:
                # Submit a bounded chunk at a time instead of all pairs at once.
                while True:
                    chunk = list(islice(pairs, chunk_size))
                    if not chunk:
                        break

                    future_to_pair = {
                        executor.submit(self.structure_similarity, patterns[i], patterns[j]): (i, j)
                        for i, j in chunk
                    }
                    for future in as_completed(future_to_pair):
                        i, j = future_to_pair[future]
                        try:
                            sim = future.result()
                        except Exception as e:
                            print(f"Error calculating similarity for pair ({i},{j}): {e}")
                            sim = 0
                        similarity_matrix[i, j] = sim
                        similarity_matrix[j, i] = sim

                    completed += len(chunk)
                    print(f"  Progress: {completed}/{total_pairs} ({completed/total_pairs*100:.1f}%)")
        else:
            for idx, (i, j) in enumerate(pairs):
                sim = self.structure_similarity(patterns[i], patterns[j])
                similarity_matrix[i, j] = sim
                similarity_matrix[j, i] = sim

                if idx % 1000 == 0:
                    print(f"  Progress: {idx}/{total_pairs} ({idx/total_pairs*100:.1f}%)")

        return similarity_matrix
    
    def find_cluster_representatives_parallel(self, labels, unique_patterns, stems):
        """Map each cluster label to its most frequent member pattern.

        Args:
            labels: cluster label for each entry of ``unique_patterns`` (same order).
            unique_patterns: the distinct ``stem_structure`` strings that were clustered.
            stems: stem dicts; occurrences of each ``'stem_structure'`` define frequency.

        Returns:
            ``{label: representative_pattern}``. Ties go to the member listed first in
            ``unique_patterns`` (``max`` keeps the first maximum).

        Frequencies are counted in one pass over ``stems``. The earlier version rescanned
        every stem once per pattern, which was O(len(unique_patterns) * len(stems)).
        The work is pure Python, so a thread pool adds no speed under the GIL and is
        not used; the method name is kept for compatibility.
        """
        clusters = defaultdict(list)
        for pattern, label in zip(unique_patterns, labels):
            clusters[label].append(pattern)

        print(f"Finding representatives for {len(clusters)} clusters...")

        pattern_counts = Counter(stem['stem_structure'] for stem in stems)
        return {
            label: max(members, key=lambda p: pattern_counts[p])
            for label, members in clusters.items()
        }
    
    def find_single_cluster_representative(self, cluster_id, patterns, stems):
        """Return the member of *patterns* that occurs most often among *stems*.

        Args:
            cluster_id: label of the cluster (unused; kept for the executor call signature).
            patterns: candidate ``stem_structure`` strings of one cluster.
            stems: stem dicts carrying ``'stem_structure'``.

        Returns:
            The most frequent pattern (ties go to the first one in *patterns*; patterns
            absent from *stems* count as 0), or ``None`` when *patterns* is empty.
        """
        if not patterns:
            return None

        # One pass over the stems instead of one full scan per candidate pattern.
        members = set(patterns)
        counts = Counter(
            stem['stem_structure'] for stem in stems if stem['stem_structure'] in members
        )
        return max(patterns, key=lambda p: counts[p])
    
    def update_stems_with_clusters_parallel(self, stems, cluster_representatives, labels, unique_patterns):
        """Relabel every stem with its cluster's representative pattern.

        Args:
            stems: stem dicts carrying ``'stem_structure'``.
            cluster_representatives: ``{label: representative_pattern}``.
            labels: cluster label for each entry of ``unique_patterns`` (same order).
            unique_patterns: distinct patterns that were clustered.

        Returns:
            List of stems in the SAME order as *stems*, produced by
            ``update_stem_chunk``. With ``self.threads > 1`` and more than 1000 stems
            the work is split into chunks handled by a thread pool. A chunk whose update
            raises is reported and kept unmodified.
        """
        pattern_to_representative = {
            pattern: cluster_representatives[label]
            for pattern, label in zip(unique_patterns, labels)
            if label in cluster_representatives
        }

        print(f"Updating {len(stems)} stems with cluster representatives...")

        if self.threads > 1 and len(stems) > 1000:
            chunk_size = max(1000, len(stems) // (self.threads * 4))
            stem_chunks = [stems[i:i + chunk_size] for i in range(0, len(stems), chunk_size)]
            updated_chunks = [None] * len(stem_chunks)
            completed = 0

            with ThreadPoolExecutor(max_workers=self.threads) as executor:
                future_to_index = {
                    executor.submit(self.update_stem_chunk, chunk, pattern_to_representative): idx
                    for idx, chunk in enumerate(stem_chunks)
                }

                for future in as_completed(future_to_index):
                    idx = future_to_index[future]
                    try:
                        updated_chunks[idx] = future.result()
                        completed += 1
                        print(f"  Progress: {completed}/{len(stem_chunks)} chunks completed")
                    except Exception as e:
                        print(f"Error updating stem chunk: {e}")
                        # Keep the original stems for this chunk as a fallback.
                        updated_chunks[idx] = stem_chunks[idx]

            # as_completed yields in completion order; reassemble by chunk index so
            # the output order is deterministic and matches the input.
            clustered_stems = [stem for chunk in updated_chunks for stem in chunk]
        else:
            clustered_stems = self.update_stem_chunk(stems, pattern_to_representative)

        return clustered_stems
    
    def update_stem_chunk(self, stem_chunk, pattern_to_representative):
        """Relabel one chunk of stems with their cluster representatives.

        Used by ``update_stems_with_clusters_parallel`` both per chunk (threaded path)
        and for the whole list (sequential path).

        Args:
            stem_chunk: stem dicts carrying a ``'stem_structure'`` key.
            pattern_to_representative: maps an original ``stem_structure`` to the
                representative pattern of its cluster.

        Returns:
            A new list in the same order. Stems whose pattern is mapped become shallow
            copies with ``'stem_structure'`` replaced by the representative,
            ``'original_pattern'`` holding the previous value and ``'cluster_id'`` set to
            ``hash(representative)``. Unmapped stems are passed through unchanged.
            Note that Python salts ``str`` hashes per process unless PYTHONHASHSEED is
            fixed, so ``cluster_id`` is only stable within a single run.
        """
        updated_stems = []
        for stem in stem_chunk:
            pattern = stem['stem_structure']
            if pattern not in pattern_to_representative:
                updated_stems.append(stem)
                continue

            representative = pattern_to_representative[pattern]
            new_stem = stem.copy()
            new_stem['stem_structure'] = representative
            new_stem['original_pattern'] = pattern
            new_stem['cluster_id'] = hash(representative)
            updated_stems.append(new_stem)
        return updated_stems
    
    def calculate_structure_similarity(self, patterns):
        """Sequentially build the symmetric pairwise similarity matrix for *patterns*.

        The diagonal is 1.0; every pair (i, j) with i < j is scored once by
        ``structure_similarity`` (in row-major order) and mirrored to (j, i).
        """
        size = len(patterns)
        matrix = np.eye(size)
        for row, col in combinations(range(size), 2):
            score = self.structure_similarity(patterns[row], patterns[col])
            matrix[row, col] = matrix[col, row] = score
        return matrix
    
    def structure_similarity(self, pattern1, pattern2):
        """Score the similarity of two dot-bracket patterns.

        The score is ``0.6 * edit_sim + 0.4 * struct_sim`` where ``edit_sim`` is
        ``1 - Levenshtein distance / longer length`` and ``struct_sim`` is the cosine
        similarity from ``structural_feature_similarity``. Both parts lie in [0, 1]
        (the feature values are non-negative), so the result does too.
        Two empty patterns are treated as identical (1.0).
        """
        longest = max(len(pattern1), len(pattern2))
        if longest == 0:
            # Both empty: identical, and avoids dividing by zero below.
            return 1.0

        edit_sim = 1 - (Levenshtein.distance(pattern1, pattern2) / longest)
        struct_sim = self.structural_feature_similarity(pattern1, pattern2)

        return 0.6 * edit_sim + 0.4 * struct_sim
    
    def structural_feature_similarity(self, pattern1, pattern2):
        """Return the cosine similarity of the two patterns' structural feature vectors.

        Returns 0.0 when either feature dict is empty or has zero norm.
        """
        vec_a = self.extract_structural_features(pattern1)
        vec_b = self.extract_structural_features(pattern2)
        if not (vec_a and vec_b):
            return 0.0

        # Keys missing from one vector count as 0, so summing over the union of
        # keys is a plain sparse dot product.
        all_keys = set(vec_a.keys()) | set(vec_b.keys())
        dot = sum(vec_a.get(key, 0) * vec_b.get(key, 0) for key in all_keys)
        magnitude_a = np.sqrt(sum(value**2 for value in vec_a.values()))
        magnitude_b = np.sqrt(sum(value**2 for value in vec_b.values()))

        if magnitude_a == 0 or magnitude_b == 0:
            return 0.0
        return dot / (magnitude_a * magnitude_b)
    
    def extract_structural_features(self, pattern):
        """Extract a dict of normalised structural features from a dot-bracket pattern.

        Features: ``length`` (/50), ``paired_ratio``, ``unpaired_ratio``,
        ``stem_count`` (number of '(' /10), ``max_stem_run`` (longest '(' run /10),
        ``max_loop_run`` (longest '.' run /10) and ``symmetry`` (1.0 if every '(' is
        closed, else 0.5). All values are non-negative.

        Returns an empty dict for an empty pattern (the ratios would otherwise divide
        by zero); ``structural_feature_similarity`` treats an empty dict as
        similarity 0.0.
        """
        if not pattern:
            return {}

        features = {}

        # Basic composition
        features['length'] = len(pattern) / 50.0
        features['paired_ratio'] = (pattern.count('(') + pattern.count(')')) / len(pattern)
        features['unpaired_ratio'] = pattern.count('.') / len(pattern)
        features['stem_count'] = pattern.count('(') / 10.0

        # Run-length (continuity) features
        features['max_stem_run'] = self.max_consecutive_count(pattern, '(') / 10.0
        features['max_loop_run'] = self.max_consecutive_count(pattern, '.') / 10.0

        # Bracket-balance score
        features['symmetry'] = self.calculate_symmetry(pattern)

        return features
    
    def max_consecutive_count(self, pattern, char):
        """Return the length of the longest run of ``char`` in ``pattern`` (0 if none)."""
        longest = run = 0
        for symbol in pattern:
            run = run + 1 if symbol == char else 0
            if run > longest:
                longest = run
        return longest
    
    def calculate_symmetry(self, pattern):
        """Return a crude bracket-balance score: 1.0 if every '(' is closed, else 0.5.

        A ')' with no open '(' before it is ignored.
        """
        depth = 0
        for ch in pattern:
            if ch == '(':
                depth += 1
            elif ch == ')' and depth:
                depth -= 1
        return 0.5 if depth else 1.0


    def extract_stems_parallel(self, sequences, structures, desc="Extracting stems"):
        """Extract stem-loops for all sequences using worker processes.

        Always delegates to the multiprocess implementation, so extraction runs in
        separate processes rather than threads sharing the GIL.
        """
        stems = self._extract_stems_parallel_multiprocess(sequences, structures, desc)
        return stems

    def _extract_stems_parallel_multiprocess(self, sequences, structures, desc):
        """Extract stem-loops for all (sequence, structure) pairs in a process pool.

        One task per pair is submitted to ``extract_stems_single_standalone`` (a
        module-level function not shown in this section; presumably the process-safe
        counterpart of ``extract_stems_single`` -- TODO confirm), with this miner's
        ``min_stem_length`` and ``max_stem_length`` bound as keyword arguments.
        ``self.threads`` is used as the number of worker processes.

        Args:
            sequences: sequences, paired positionally with ``structures``.
            structures: dot-bracket strings; ``zip`` truncates to the shorter list.
            desc: label for the progress messages / progress bar.

        Returns:
            A flat list of all stems returned by the tasks, collected in submission
            order. A task that raises is reported and contributes nothing.
        """
        all_stems = []
        
        # Build the (sequence, structure) task arguments.
        tasks = list(zip(sequences, structures))
        
        print(f"{desc} using {self.threads} process(es)...")
        print(f"Total sequences: {len(tasks)}")
        
        # Worker processes, so extraction is not limited by the GIL.
        with ProcessPoolExecutor(max_workers=self.threads) as executor:
            # Bind the stem-length limits with functools.partial.
            from functools import partial
            extract_func = partial(extract_stems_single_standalone, 
                                min_stem_length=self.min_stem_length,
                                max_stem_length=self.max_stem_length)
            
            # Submit one future per pair (all at once).
            futures = [executor.submit(extract_func, seq, struct) for seq, struct in tasks]
            
            # Collect results by iterating the futures list, i.e. in submission order.
            # NOTE(review): `completed` is incremented but never read.
            completed = 0
            try:
                from tqdm import tqdm
                with tqdm(total=len(tasks), desc=desc) as pbar:
                    for future in futures:
                        try:
                            stems = future.result()
                            all_stems.extend(stems)
                        except Exception as e:
                            print(f"Stem extraction failed: {e}")
                        finally:
                            completed += 1
                            pbar.update(1)
            except ImportError:
                # Fallback without tqdm: plain progress every 1000 sequences.
                # NOTE(review): the file also runs `import tqdm` at module level, so if tqdm
                # were missing the module would fail to import first; this branch looks unreachable.
                for i, future in enumerate(futures):
                    try:
                        stems = future.result()
                        all_stems.extend(stems)
                    except Exception as e:
                        print(f"Stem extraction failed for sequence {i}: {e}")
                    if i % 1000 == 0:
                        print(f"  Processed {i}/{len(tasks)} sequences")
        
        # Reports len(sequences), which can exceed len(tasks) if the inputs differ in length.
        print(f"  Extracted {len(all_stems)} stems from {len(sequences)} sequences")
        return all_stems

    def run(self, positive_file, negative_file):
        """Run the complete stem-loop pattern-mining pipeline (the method pipeline.py calls).

        Steps: parse both structure files, mine patterns, compare stem
        characteristics, identify consensus motifs, create visualisations, save the
        results, then print a summary of the key findings.

        Args:
            positive_file: path of the positive-set structure file.
            negative_file: path of the negative-set structure file.

        Returns:
            ``(pattern_analysis, positive_stems, negative_stems)`` on success, or
            ``(None, None, None)`` when either file yields no sequences or when any
            step raises (the traceback is printed).
        """
        print("Starting stem-loop pattern mining...")

        try:
            pos_sequences, pos_structures, neg_sequences, neg_structures = self.parse_structure_files(
                positive_file, negative_file
            )

            if not pos_sequences or not neg_sequences:
                print("Error: No sequences loaded from input files!")
                return None, None, None

            print(f"Loaded {len(pos_sequences)} positive and {len(neg_sequences)} negative sequences")

            pattern_analysis, positive_stems, negative_stems = self.mine_patterns(
                pos_sequences, pos_structures, neg_sequences, neg_structures
            )
            characteristic_stats = self.analyze_stem_characteristics(positive_stems, negative_stems)
            motifs = self.identify_consensus_motifs(pattern_analysis)
            self.create_pattern_visualization(pattern_analysis, characteristic_stats)
            self.save_results(pattern_analysis, characteristic_stats, motifs)

            print("\n" + "=" * 60)
            print("KEY DISCOVERIES")
            print("=" * 60)
            self._report_top_pattern(pattern_analysis)
            self._report_stem_characteristics(characteristic_stats)
            if pattern_analysis:
                self._report_significance_summary(pattern_analysis)

            print("\nPattern mining completed successfully!")
            return pattern_analysis, positive_stems, negative_stems

        except Exception as e:
            # Top-level boundary: report with traceback and signal failure to the caller.
            print(f"Error during pattern mining: {e}")
            import traceback
            traceback.print_exc()
            return None, None, None

    def _report_top_pattern(self, pattern_analysis):
        """Print the first entry of *pattern_analysis*, or a notice when it is empty."""
        if not pattern_analysis:
            print("No significant patterns found")
            return

        # NOTE(review): treats the first entry as the top pattern; assumes mine_patterns
        # returns them sorted most-significant first -- confirm.
        top_pattern = pattern_analysis[0]
        print("Top significant stem-loop pattern:")
        print(f"  Pattern: {top_pattern['pattern']}")
        print(f"  Enrichment: {top_pattern['enrichment']:.2f}x")
        print(f"  Adjusted p-value: {top_pattern['adjusted_p_value']:.6f}")
        print(f"  Positive frequency: {top_pattern['positive_frequency']:.4f}")
        print(f"  Negative frequency: {top_pattern['negative_frequency']:.4f}")
        print(f"  Positive count: {top_pattern['positive_count']}")
        print(f"  Negative count: {top_pattern['negative_count']}")

    def _report_stem_characteristics(self, characteristic_stats):
        """Print characteristics whose (unadjusted) p-value is below 0.05 with effect direction."""
        if not characteristic_stats:
            print("No characteristic statistics available")
            return

        print("\nSignificant stem characteristics (p < 0.05):")
        significant_found = False
        # `result` (not `stats`) so the module-level scipy `stats` is not shadowed.
        for name, result in characteristic_stats.items():
            if result['p_value'] < 0.05:
                significant_found = True
                direction = "higher" if result['cohens_d'] > 0 else "lower"
                print(f"  {name}: {direction} in positives (d = {result['cohens_d']:.3f}, p = {result['p_value']:.4f})")

        if not significant_found:
            print("  No significant characteristics found")

    def _report_significance_summary(self, pattern_analysis):
        """Print how many patterns pass adjusted p-value cut-offs of 0.05, 0.01 and 0.001."""
        print("\nSignificance summary:")
        print(f"  Total patterns analyzed: {len(pattern_analysis)}")
        print(f"  Significant patterns (p < 0.05): {sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.05)}")
        print(f"  Highly significant patterns (p < 0.01): {sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.01)}")
        print(f"  Very highly significant patterns (p < 0.001): {sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.001)}")
        
    def parse_structure_files(self, positive_file, negative_file):
        """Parse the positive and negative structure files into sequences and structures.

        Each record starts with a '>' header line. Within a record, a line made only
        of ACGU letters (any case) becomes the sequence (upper-cased). A line that
        contains '(', ')' or '.' otherwise becomes the structure. Only its first
        whitespace-separated token is kept, so a trailing free-energy annotation
        (e.g. RNAfold-style ``((..)) (-1.20)``) does not end up in the structure.
        Blank lines are skipped. A later matching line overwrites an earlier one.
        Records missing either a sequence or a structure are dropped.

        Returns:
            ``(pos_sequences, pos_structures, neg_sequences, neg_structures)``.
        """
        def parse_file(filename):
            sequences = []
            structures = []
            current_seq = ""
            current_struct = ""
            with open(filename, 'r') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    if line.startswith('>'):
                        if current_seq and current_struct:
                            sequences.append(current_seq)
                            structures.append(current_struct)
                        current_seq = ""
                        current_struct = ""
                    elif all(c in 'ACGUacgu' for c in line):
                        current_seq = line.upper()
                    elif any(c in '().' for c in line):
                        # Drop anything after the dot-bracket token (e.g. an energy),
                        # otherwise the structure length no longer matches the sequence.
                        current_struct = line.split()[0]
            # Flush the last record.
            if current_seq and current_struct:
                sequences.append(current_seq)
                structures.append(current_struct)
            return sequences, structures

        print("Loading positive sequences...")
        pos_sequences, pos_structures = parse_file(positive_file)
        print("Loading negative sequences...")
        neg_sequences, neg_structures = parse_file(negative_file)

        print(f"Positive: {len(pos_sequences)} sequences, {len(pos_structures)} structures")
        print(f"Negative: {len(neg_sequences)} sequences, {len(neg_structures)} structures")

        return pos_sequences, pos_structures, neg_sequences, neg_structures
    
    def extract_detailed_stem_loops(self, sequence, structure, min_stem_length=1):
        """Extract one stem-loop record per base pair of a dot-bracket structure.

        Pairs are visited in order of their closing bracket. From each pair
        (start, end), the stem grows inward over consecutive *stacked* pairs
        ``(start + k, end - k)``. The loop is everything strictly between the two
        arms and may itself contain paired bases. Nested pairs produce their own
        records, so the inner part of a helix is reported as well. Unmatched
        brackets are ignored. ``self.max_stem_length`` is not applied here.

        Args:
            sequence: nucleotide string; must have the same length as ``structure``.
                GC content counts upper-case 'G'/'C' only.
            structure: dot-bracket string.
            min_stem_length: records whose stem has fewer stacked pairs are skipped.

        Returns:
            List of dicts with keys stem_structure, left_arm, right_arm,
            loop_sequence, loop_structure, stem_length, loop_length,
            stem_gc_content, loop_unpaired_ratio, start_position, end_position,
            total_length. Empty (with a warning printed) if the lengths differ.
        """
        stems = []

        if len(sequence) != len(structure):
            print(f"Warning: Sequence length ({len(sequence)}) != Structure length ({len(structure)})")
            return stems

        # Pair table (opening index -> closing index) and the pairs in closing order.
        partner = {}
        pairs = []
        stack = []
        for i, char in enumerate(structure):
            if char == '(':
                stack.append(i)
            elif char == ')' and stack:
                start = stack.pop()
                partner[start] = i
                pairs.append((start, i))

        for start, end in pairs:
            # Extend only while (start + k, end - k) really is a base pair. Matching the
            # bracket characters alone would merge unrelated helices, e.g. it would treat
            # positions 1 and 4 of "(()())" as paired.
            stem_length = 0
            while partner.get(start + stem_length) == end - stem_length:
                stem_length += 1

            if stem_length < min_stem_length:
                continue

            stem_end_left = start + stem_length - 1
            stem_start_right = end - stem_length + 1

            left_arm = sequence[start:stem_end_left + 1]
            right_arm = sequence[stem_start_right:end + 1]
            # Empty when the two arms are adjacent.
            loop_seq = sequence[stem_end_left + 1:stem_start_right]
            loop_struct = structure[stem_end_left + 1:stem_start_right]

            stem_seq = left_arm + right_arm
            stem_gc = stem_seq.count('G') + stem_seq.count('C')
            stem_gc_content = stem_gc / len(stem_seq) if stem_seq else 0

            loop_length = len(loop_seq)
            loop_unpaired = loop_struct.count('.')
            loop_unpaired_ratio = loop_unpaired / loop_length if loop_length > 0 else 0

            stems.append({
                'stem_structure': structure[start:end + 1],
                'left_arm': left_arm,
                'right_arm': right_arm,
                'loop_sequence': loop_seq,
                'loop_structure': loop_struct,
                'stem_length': stem_length,
                'loop_length': loop_length,
                'stem_gc_content': stem_gc_content,
                'loop_unpaired_ratio': loop_unpaired_ratio,
                'start_position': start,
                'end_position': end,
                'total_length': end - start + 1
            })

        return stems
    
    def calculate_stem_energy(self, left_arm, right_arm, loop_seq):
        """Fold the hairpin with ViennaRNA (``RNA.fold``) and return its energy as a float.

        The folded sequence is ``left_arm + loop_seq + right_arm``. Sequences
        longer than 100 nt are not folded, and any exception is printed; both
        cases return 0.0.
        """
        try:
            hairpin = left_arm + loop_seq + right_arm
            # Skip long inputs to keep folding cheap / avoid problems.
            if len(hairpin) > 100:
                return 0.0
            _, mfe = RNA.fold(hairpin)
            return float(mfe)
        except Exception as e:
            print(f"Energy calculation failed: {e}")
            return 0.0

    def extract_stems_single(self, args):
        """Extract stem-loops for one sequence and attach each stem's energy.

        ``args`` is a ``(sequence, structure, min_stem_length)`` tuple so the method
        can be submitted directly to an executor. Any exception during extraction
        or energy calculation is printed and yields an empty list.
        """
        sequence, structure, min_len = args
        try:
            found = self.extract_detailed_stem_loops(sequence, structure, min_len)
            for item in found:
                item['energy'] = self.calculate_stem_energy(
                    item['left_arm'], item['right_arm'], item['loop_sequence']
                )
            return found
        except Exception as e:
            print(f"Error extracting stems for sequence: {e}")
            return []
    
    def _extract_stems_sequential(self, sequences, structures, desc):
        """Extract stem-loops (with energies) for all sequences on the calling thread.

        Progress is shown with a tqdm bar when tqdm can be imported. Otherwise a
        line is printed every 1000 sequences.
        """
        try:
            from tqdm import tqdm
        except ImportError:
            tqdm = None
            print("tqdm not available, using simple progress reporting")

        pairs = zip(sequences, structures)
        if tqdm is not None:
            pairs = tqdm(pairs, total=len(sequences), desc=desc)

        collected = []
        for index, (seq, struct) in enumerate(pairs):
            if tqdm is None and index > 0 and index % 1000 == 0:
                print(f"  Processed {index}/{len(sequences)} sequences")
            found = self.extract_detailed_stem_loops(seq, struct, self.min_stem_length)
            for item in found:
                item['energy'] = self.calculate_stem_energy(
                    item['left_arm'], item['right_arm'], item['loop_sequence']
                )
            collected.extend(found)

        return collected
    
    def _extract_stems_parallel_executor(self, sequences, structures, desc):
        """使用线程池并行提取stem-loops"""
        all_stems = []
        
        # 准备任务参数
        tasks = [(seq, struct, self.min_stem_length) for seq, struct in zip(sequences, structures)]
        
        print(f"{desc} using {self.threads} thread(s)...")
        
        with ThreadPoolExecutor(max_workers=self.threads) as executor:
            # 提交所有任务
            future_to_index = {
                executor.submit(self.extract_stems_single, task): i 
                for i, task in enumerate(tasks)
            }
            
            # 处理完成的任务
            completed = 0
            try:
                from tqdm import tqdm
                use_tqdm = True
            except ImportError:
                use_tqdm = False
            
            if use_tqdm:
                with tqdm(total=len(tasks), desc=desc) as pbar:
                    for future in as_completed(future_to_index):
                        index = future_to_index[future]
                        try:
                            stems = future.result()
                            all_stems.extend(stems)
                        except Exception as e:
                            print(f"Stem extraction failed for sequence {index}: {e}")
                        finally:
                            completed += 1
                            pbar.update(1)
            else:
                # 简单进度显示
                for future in as_completed(future_to_index):
                    index = future_to_index[future]
                    try:
                        stems = future.result()
                        all_stems.extend(stems)
                    except Exception as e:
                        print(f"Stem extraction failed for sequence {index}: {e}")
                    finally:
                        completed += 1
                        if completed % 1000 == 0:
                            print(f"  Processed {completed}/{len(tasks)} sequences")
        
        print(f"  Extracted {len(all_stems)} stems from {len(sequences)} sequences")
        return all_stems

    def mine_patterns(self, pos_sequences, pos_structures, neg_sequences, neg_structures):
        """挖掘显著的stem-loop模式 - 使用改进的统计方法"""
        print("Extracting stem-loops from positive sequences...")
        
        # 使用并行方法提取阳性stem-loops
        positive_stems = self.extract_stems_parallel(
            pos_sequences, pos_structures, "Positive sequences"
        )
        
        print("Extracting stem-loops from negative sequences...")
        # 使用并行方法提取阴性stem-loops
        negative_stems = self.extract_stems_parallel(
            neg_sequences, neg_structures, "Negative sequences"
        )
        
        # 保存stem数据供Step4使用
        self.positive_stems = positive_stems
        self.negative_stems = negative_stems
        
        print(f"Total positive stems: {len(positive_stems)}")
        print(f"Total negative stems: {len(negative_stems)}")
        
        # 应用stem分割
        if self.max_stem_length > 0:
            positive_stems = self.split_long_stems(positive_stems)
            negative_stems = self.split_long_stems(negative_stems)
            print(f"After splitting - Positive: {len(positive_stems)}, Negative: {len(negative_stems)}")
        
        # 使用改进的模式频率分析方法
        return self._analyze_pattern_frequency_improved(positive_stems, negative_stems)

    def save_results(self, pattern_analysis, characteristic_stats, motifs, output_prefix='stem_patterns'):
        """保存分析结果 - 添加stem数据保存"""
        # 使用输出目录
        output_prefix = str(self.output_dir / output_prefix)    
        
        # 保存模式分析结果
        if pattern_analysis:
            pattern_df = pd.DataFrame(pattern_analysis)
            pattern_df.to_csv(f'{output_prefix}_analysis.csv', index=False)
            print(f"- Pattern analysis: {output_prefix}_analysis.csv")
        else:
            print("- No pattern analysis data to save")
        
        # 保存stem数据供Step4使用 - 修复：确保保存stem数据
        if hasattr(self, 'positive_stems') and self.positive_stems:
            import joblib
            positive_stem_file = f'{output_prefix}_positive_stems.pkl'
            joblib.dump(self.positive_stems, positive_stem_file)
            print(f"- Positive stems: {positive_stem_file}")
        
        if hasattr(self, 'negative_stems') and self.negative_stems:
            import joblib
            negative_stem_file = f'{output_prefix}_negative_stems.pkl'
            joblib.dump(self.negative_stems, negative_stem_file)
            print(f"- Negative stems: {negative_stem_file}")
        
        # 保存特征统计
        if characteristic_stats:
            char_df = pd.DataFrame(characteristic_stats).T
            char_df.to_csv(f'{output_prefix}_characteristics.csv')
            print(f"- Characteristics: {output_prefix}_characteristics.csv")
        else:
            print("- No characteristic data to save")
        
        # 保存motif分析
        if motifs:
            motif_df = pd.DataFrame(motifs)
            # 过滤正富集的motif
            positive_enriched_motifs = motif_df[
                (motif_df['enrichment'] > 1) | (motif_df['enrichment'] == float('inf'))
            ]
            if not positive_enriched_motifs.empty:
                positive_enriched_motifs.to_csv(f'{output_prefix}_motifs.csv', index=False)
                print(f"- Motifs: {output_prefix}_motifs.csv")
            else:
                print("- No positively enriched motifs found")
        else:
            print("- No motif data to save")
        
        # 保存总结报告
        summary = {
            'total_patterns_analyzed': len(pattern_analysis) if pattern_analysis else 0,
            'significant_patterns_p05': sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.05) if pattern_analysis else 0,
            'significant_patterns_p01': sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.01) if pattern_analysis else 0,
            'significant_patterns_p001': sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.001) if pattern_analysis else 0,
            'top_pattern': pattern_analysis[0]['pattern'] if pattern_analysis else None,
            'max_enrichment': max(p['enrichment'] for p in pattern_analysis) if pattern_analysis else 0,
            'positive_stems_count': len(self.positive_stems) if hasattr(self, 'positive_stems') else 0,
            'negative_stems_count': len(self.negative_stems) if hasattr(self, 'negative_stems') else 0,
            'analysis_timestamp': pd.Timestamp.now().isoformat()
        }
        
        with open(f'{output_prefix}_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)
        
        print(f"- Summary: {output_prefix}_summary.json")
        print(f"\nResults saved to '{self.output_dir}' directory")

    def _analyze_pattern_frequency_improved(self, positive_stems, negative_stems):
        """改进的模式频率分析 - 保留所有motif"""
        # 按结构模式分组
        positive_patterns = Counter([stem['stem_structure'] for stem in positive_stems])
        negative_patterns = Counter([stem['stem_structure'] for stem in negative_stems])
        
        # 关键修改：收集所有出现的模式，不进行频率过滤
        all_patterns = set(positive_patterns.keys()) | set(negative_patterns.keys())
        
        print(f"Analyzing {len(all_patterns)} patterns")
        
        pattern_analysis = []
        
        total_pos = len(positive_stems)
        total_neg = len(negative_stems)
        
        for pattern in all_patterns:
            pos_count = positive_patterns.get(pattern, 0)
            neg_count = negative_patterns.get(pattern, 0)
            
            if total_pos > 0 and total_neg > 0:
                pos_freq = pos_count / total_pos
                neg_freq = neg_count / total_neg
                
                # 计算富集度
                if neg_freq > 0:
                    enrichment = pos_freq / neg_freq
                else:
                    enrichment = float('inf') if pos_freq > 0 else 1.0
                
                # 改进的Fisher精确检验 - 使用连续性校正
                try:
                    # 添加连续性校正（每个单元格加0.5）
                    contingency_table = [
                        [pos_count + 0.5, total_pos - pos_count + 0.5],
                        [neg_count + 0.5, total_neg - neg_count + 0.5]
                    ]
                    _, p_value_two_sided = stats.fisher_exact(contingency_table, alternative='two-sided')
                    
                    # 手动计算单侧p值（更敏感）
                    if pos_freq > neg_freq:
                        p_value = p_value_two_sided / 2  # 单侧检验
                    else:
                        p_value = 1 - p_value_two_sided / 2
                        
                except:
                    p_value = 1.0
                
                # 计算效应量 - 相对风险和绝对风险差
                if neg_freq > 0:
                    relative_risk = pos_freq / neg_freq
                    risk_difference = pos_freq - neg_freq
                else:
                    relative_risk = float('inf')
                    risk_difference = pos_freq
                
                # 计算出现比例
                occurrence_ratio = pos_count / (pos_count + neg_count) if (pos_count + neg_count) > 0 else 0
                
                pattern_analysis.append({
                    'pattern': pattern,
                    'positive_count': pos_count,
                    'negative_count': neg_count,
                    'positive_frequency': pos_freq,
                    'negative_frequency': neg_freq,
                    'enrichment': enrichment,
                    'p_value': p_value,
                    'relative_risk': relative_risk,
                    'risk_difference': risk_difference,
                    'occurrence_ratio': occurrence_ratio,
                    'total_occurrence': pos_count + neg_count
                })
        
        # 多重检验校正 - 使用更稳健的方法
        if pattern_analysis:
            p_values = [p['p_value'] for p in pattern_analysis]
            
            # 方法1: FDR Benjamini-Yekutieli（对依赖关系更稳健）
            _, corrected_pvals_by, _, _ = multipletests(p_values, method='fdr_by')
            
            # 方法2: Bonferroni（更严格）
            _, corrected_pvals_bonf, _, _ = multipletests(p_values, method='bonferroni')
            
            # 选择两者中较小的p值（更敏感）
            for i, pval in enumerate(corrected_pvals_by):
                pattern_analysis[i]['adjusted_p_value'] = min(corrected_pvals_by[i], corrected_pvals_bonf[i])
                pattern_analysis[i]['adjusted_p_value_by'] = corrected_pvals_by[i]
                pattern_analysis[i]['adjusted_p_value_bonf'] = corrected_pvals_bonf[i]
            
            # 排序：先按调整后p值，再按富集度，再按counts
            pattern_analysis.sort(key=lambda x: (x['adjusted_p_value'], -x['enrichment'], -x['positive_count']))
            
            # 保留原有的输出统计信息 - 这是重要的！
            top_patterns = pattern_analysis[:10]
            print(f"\nTop 10 significant patterns:")
            for i, pattern in enumerate(top_patterns):
                print(f"  {i+1}. {pattern['pattern']}: "
                    f"enrichment={pattern['enrichment']:.2f}x, "
                    f"p={pattern['adjusted_p_value']:.2e}, "
                    f"pos_count={pattern['positive_count']}")
            
            # 添加额外的统计信息
            significant_count = sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.05)
            print(f"\nStatistical summary:")
            print(f"  Total patterns analyzed: {len(pattern_analysis)}")
            print(f"  Significant patterns (p < 0.05): {significant_count}")
            print(f"  Highly significant (p < 0.01): {sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.01)}")
            print(f"  Very highly significant (p < 0.001): {sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.001)}")
            
        else:
            print("Warning: No patterns found for analysis")
        
        return pattern_analysis, positive_stems, negative_stems
    
    def analyze_stem_characteristics(self, positive_stems, negative_stems):
        """Compare numeric stem features between positive and negative stems.

        For each feature in a fixed list, values are gathered from stems that
        carry the key and whose value is not NaN. When both groups have more
        than five values, the method runs Welch's t-test
        (``ttest_ind(..., equal_var=False)``). It also computes a Cohen's d
        whose denominator is ``sqrt((s_pos**2 + s_neg**2) / 2)``.

        Returns:
            dict: ``feature -> stats dict``. Features with too few samples are
            omitted. A feature whose analysis raises gets a reduced dict with
            NaN statistics and ``p_value`` 1.0, and the error is printed.
        """
        feature_names = (
            'stem_length', 'loop_length', 'stem_gc_content',
            'loop_unpaired_ratio', 'energy', 'total_length',
        )

        def usable_values(group, feature):
            # Only stems that have the feature and a non-NaN value count.
            return [s[feature] for s in group if feature in s and not np.isnan(s[feature])]

        results = {}

        print("Analyzing stem characteristics...")

        for feature in feature_names:
            pos_vals = usable_values(positive_stems, feature)
            neg_vals = usable_values(negative_stems, feature)

            # Require a minimum sample size in both groups.
            if len(pos_vals) <= 5 or len(neg_vals) <= 5:
                continue

            try:
                _t_stat, p_val = stats.ttest_ind(pos_vals, neg_vals, equal_var=False)

                mean_pos = np.mean(pos_vals)
                mean_neg = np.mean(neg_vals)
                sd_pos = np.std(pos_vals, ddof=1)
                sd_neg = np.std(neg_vals, ddof=1)
                pooled = np.sqrt((sd_pos ** 2 + sd_neg ** 2) / 2)
                effect = (mean_pos - mean_neg) / pooled if pooled != 0 else 0

                results[feature] = {
                    'positive_mean': mean_pos,
                    'negative_mean': mean_neg,
                    'positive_std': sd_pos,
                    'negative_std': sd_neg,
                    'mean_difference': mean_pos - mean_neg,
                    'cohens_d': effect,
                    'p_value': p_val,
                    'positive_samples': len(pos_vals),
                    'negative_samples': len(neg_vals)
                }
            except Exception as e:
                print(f"Error analyzing {feature}: {e}")
                results[feature] = {
                    'positive_mean': np.nan,
                    'negative_mean': np.nan,
                    'mean_difference': np.nan,
                    'cohens_d': np.nan,
                    'p_value': 1.0
                }

        return results
    
    def split_long_stems(self, stems, max_stem_length=None):
        """将长stem分割成较短的stem - 优化版本"""
        if max_stem_length is None:
            max_stem_length = self.max_stem_length
            
        print(f"Splitting long stems (max_length={max_stem_length})...")
        
        # 如果没有需要分割的stem，直接返回
        if not stems:
            return stems
        
        # 检查是否有需要分割的stem
        long_stems_count = sum(1 for stem in stems if stem['stem_length'] > max_stem_length)
        if long_stems_count == 0:
            print(f"No stems need splitting (all stems <= {max_stem_length})")
            return stems
        
        print(f"Found {long_stems_count} stems that need splitting")
        
        # 根据数据量选择处理方式
        if len(stems) > 5000 and self.threads > 1:
            return self._split_long_stems_multiprocess(stems, max_stem_length)
        else:
            return self._split_long_stems_sequential_optimized(stems, max_stem_length)

    def _split_long_stems_multiprocess(self, stems, max_stem_length):
        """多进程版本 - 使用进程池，修复性能问题"""
        print(f"Using multiprocessing for stem splitting with {self.threads} processes...")
        
        # 如果stem数量较少，使用单线程
        if len(stems) < 1000:
            return self._split_long_stems_sequential_optimized(stems, max_stem_length)
        
        # 准备任务 - 只传递必要数据，避免传递整个对象
        tasks = []
        for stem in stems:
            # 只提取必要字段，避免传递复杂对象
            simple_stem = {
                'stem_structure': stem['stem_structure'],
                'left_arm': stem['left_arm'],
                'right_arm': stem['right_arm'],
                'loop_sequence': stem['loop_sequence'],
                'loop_structure': stem['loop_structure'],
                'stem_length': stem['stem_length'],
                'loop_length': stem['loop_length'],
                'stem_gc_content': stem['stem_gc_content'],
                'loop_unpaired_ratio': stem['loop_unpaired_ratio'],
                'start_position': stem['start_position'],
                'end_position': stem['end_position'],
                'total_length': stem['total_length']
            }
            tasks.append((simple_stem, max_stem_length))
        
        new_stems = []
        
        # 使用进程池，设置合适的chunk_size
        chunk_size = max(100, len(tasks) // (self.threads * 4))
        
        with ProcessPoolExecutor(max_workers=self.threads) as executor:
            # 使用map而不是submit，提高效率
            try:
                results = executor.map(
                    self._process_single_stem_simple, 
                    [task[0] for task in tasks],
                    [task[1] for task in tasks],
                    chunksize=chunk_size
                )
                
                # 收集结果 - 使用进度条
                try:
                    from tqdm import tqdm
                    # 使用tqdm进度条
                    for result_stems in tqdm(results, total=len(tasks), desc="Splitting stems (multiprocess)", unit="stem"):
                        new_stems.extend(result_stems)
                except ImportError:
                    # tqdm不可用时的回退方案
                    completed = 0
                    for result_stems in results:
                        new_stems.extend(result_stems)
                        completed += 1
                        if completed % 1000 == 0:
                            print(f"  Processed {completed}/{len(tasks)} stems")
                            
            except Exception as e:
                print(f"Multiprocessing error: {e}, falling back to sequential processing")
                return self._split_long_stems_sequential_optimized(stems, max_stem_length)
        
        print(f"Stem splitting: {len(stems)} → {len(new_stems)} stems (max_length={max_stem_length})")
        return new_stems

    def _process_single_stem_simple(self, simple_stem, max_stem_length):
        """处理单个stem的分割 - 简化版本，避免复杂计算"""
        stem_length = simple_stem['stem_length']
        result_stems = []
        
        # 如果stem长度超过阈值，进行分割
        if stem_length > max_stem_length:
            # 计算可以分割的数量
            num_splits = stem_length // max_stem_length
            if stem_length % max_stem_length > 0:
                num_splits += 1
            
            for j in range(num_splits):
                start_idx = j * max_stem_length
                end_idx = min((j + 1) * max_stem_length, stem_length)
                
                # 创建新的stem对象 - 不计算能量
                new_stem = self._create_sub_stem_simple(simple_stem, start_idx, end_idx)
                if new_stem:
                    result_stems.append(new_stem)
        else:
            # 不需要分割，直接保留原stem
            result_stems.append(simple_stem)
        
        return result_stems

    def _create_sub_stem_simple(self, original_stem, start_idx, end_idx):
        """从原stem创建子stem - 简化版本，不计算能量"""
        try:
            stem_length = end_idx - start_idx
            
            # 提取子stem的左右臂
            left_arm = original_stem['left_arm'][start_idx:end_idx]
            right_arm = original_stem['right_arm'][-(end_idx):len(original_stem['right_arm']) - start_idx]

            # 构建子stem的结构表示
            stem_structure = '(' * stem_length + original_stem['loop_structure'] + ')' * stem_length
            
            # 计算新的位置信息
            total_length = stem_length * 2 + original_stem['loop_length']
            start_position = original_stem['start_position'] + start_idx
            end_position = start_position + total_length - 1
            
            new_stem = {
                'stem_structure': stem_structure,
                'left_arm': left_arm,
                'right_arm': right_arm,
                'loop_sequence': original_stem['loop_sequence'],
                'loop_structure': original_stem['loop_structure'],
                'stem_length': stem_length,
                'loop_length': original_stem['loop_length'],
                'stem_gc_content': original_stem['stem_gc_content'],
                'loop_unpaired_ratio': original_stem['loop_unpaired_ratio'],
                'start_position': start_position,
                'end_position': end_position,
                'total_length': total_length,
                'energy': 0.0  # 设置为0，后续需要时再计算
            }
            return new_stem
        except Exception as e:
            print(f"Error creating sub-stem: {e}")
            return None


    def _split_long_stems_sequential_optimized(self, stems, max_stem_length):
        """优化的单线程版本"""
        new_stems = []
        long_stem_count = 0
        
        try:
            from tqdm import tqdm
            stem_iterator = tqdm(stems, desc="Splitting stems", unit="stem")
            use_tqdm = True
        except ImportError:
            stem_iterator = stems
            use_tqdm = False
            print("tqdm not available, using simple progress reporting")
        
        for i, stem in enumerate(stem_iterator):
            stem_length = stem['stem_length']
            
            # 如果stem长度超过阈值，进行分割
            if stem_length > max_stem_length:
                long_stem_count += 1
                # 计算可以分割的数量
                num_splits = stem_length // max_stem_length
                if stem_length % max_stem_length > 0:
                    num_splits += 1
                
                for j in range(num_splits):
                    start_idx = j * max_stem_length
                    end_idx = min((j + 1) * max_stem_length, stem_length)
                    
                    # 创建新的stem对象 - 不计算能量
                    new_stem = self._create_sub_stem_simple(stem, start_idx, end_idx)
                    if new_stem:
                        new_stems.append(new_stem)
            else:
                # 不需要分割，直接保留原stem
                new_stems.append(stem)
            
            # 如果没有使用tqdm，显示简单进度
            if not use_tqdm and i > 0 and i % 5000 == 0:
                print(f"  Processed {i}/{len(stems)} stems, found {long_stem_count} long stems")
        
        print(f"Stem splitting: {len(stems)} → {len(new_stems)} stems ({long_stem_count} long stems split)")
        return new_stems

    def _split_long_stems_parallel(self, stems, max_stem_length):
        """多线程版本 - 带进度条"""
        print(f"Using {self.threads} threads for stem splitting...")
        
        # 准备任务
        tasks = []
        for stem in stems:
            tasks.append((stem, max_stem_length))
        
        new_stems = []
        
        # 使用线程池
        with ThreadPoolExecutor(max_workers=self.threads) as executor:
            # 提交所有任务
            future_to_stem = {
                executor.submit(self._process_single_stem, stem, max_stem_length): i 
                for i, (stem, max_stem_length) in enumerate(tasks)
            }
            
            # 处理完成的任务
            completed = 0
            try:
                from tqdm import tqdm
                use_tqdm = True
            except ImportError:
                use_tqdm = False
            
            if use_tqdm:
                with tqdm(total=len(tasks), desc="Splitting stems") as pbar:
                    for future in as_completed(future_to_stem):
                        try:
                            result_stems = future.result()
                            new_stems.extend(result_stems)
                        except Exception as e:
                            print(f"Stem splitting failed: {e}")
                            # 如果失败，使用原始stem
                            idx = future_to_stem[future]
                            new_stems.append(stems[idx])
                        finally:
                            completed += 1
                            pbar.update(1)
            else:
                # 简单进度显示
                for future in as_completed(future_to_stem):
                    try:
                        result_stems = future.result()
                        new_stems.extend(result_stems)
                    except Exception as e:
                        print(f"Stem splitting failed: {e}")
                        idx = future_to_stem[future]
                        new_stems.append(stems[idx])
                    finally:
                        completed += 1
                        if completed % 1000 == 0:
                            print(f"  Processed {completed}/{len(tasks)} stems")
        
        print(f"Stem splitting: {len(stems)} → {len(new_stems)} stems (max_length={max_stem_length})")
        return new_stems

    def _process_single_stem(self, stem, max_stem_length):
        """处理单个stem的分割 - 用于多线程"""
        stem_length = stem['stem_length']
        result_stems = []
        
        # 如果stem长度超过阈值，进行分割
        if stem_length > max_stem_length:
            # 计算可以分割的数量
            num_splits = stem_length // max_stem_length
            if stem_length % max_stem_length > 0:
                num_splits += 1
            
            for j in range(num_splits):
                start_idx = j * max_stem_length
                end_idx = min((j + 1) * max_stem_length, stem_length)
                
                # 创建新的stem对象
                new_stem = self._create_sub_stem(stem, start_idx, end_idx)
                if new_stem:
                    result_stems.append(new_stem)
        else:
            # 不需要分割，直接保留原stem
            result_stems.append(stem)
        
        return result_stems
    
    def _create_sub_stem(self, original_stem, start_idx, end_idx):
        """从原stem创建子stem"""
        try:
            stem_length = end_idx - start_idx
            
            # 提取子stem的左右臂
            left_arm = original_stem['left_arm'][start_idx:end_idx]
            # right_arm = original_stem['right_arm'][-(end_idx):-(start_idx) if start_idx > 0 else None]
            right_arm = original_stem['right_arm'][-(end_idx):len(original_stem['right_arm']) - start_idx]

            # 构建子stem的结构表示
            stem_structure = '(' * stem_length + original_stem['loop_structure'] + ')' * stem_length
            
            # 计算新的位置信息
            total_length = stem_length * 2 + original_stem['loop_length']
            start_position = original_stem['start_position'] + start_idx
            end_position = start_position + total_length - 1
            
            new_stem = {
                'stem_structure': stem_structure,
                'left_arm': left_arm,
                'right_arm': right_arm,
                'loop_sequence': original_stem['loop_sequence'],
                'loop_structure': original_stem['loop_structure'],
                'stem_length': stem_length,
                'loop_length': original_stem['loop_length'],
                'stem_gc_content': original_stem['stem_gc_content'],  # 近似值
                'loop_unpaired_ratio': original_stem['loop_unpaired_ratio'],
                'start_position': start_position,
                'end_position': end_position,
                'total_length': total_length,
                'energy': self.calculate_stem_energy(left_arm, right_arm, original_stem['loop_sequence'])
            }
            return new_stem
        except Exception as e:
            print(f"Error creating sub-stem: {e}")
            return None


    def identify_consensus_motifs(self, significant_patterns, top_n=10):
        """Describe and classify the first ``top_n`` patterns.

        For each pattern dict (``'pattern'``, ``'enrichment'``,
        ``'adjusted_p_value'``, ``'positive_count'`` and ``'negative_count'``
        are read), the method reports bracket counts, paired ratio and a coarse
        ``motif_type``. The first matching rule wins:
        "Long Stem" (starts with ``((`` and ends with ``))``), "Short Stem"
        (exactly one ``(`` and one ``)``), "Loop-rich" (more dots than brackets),
        otherwise "Complex".

        Returns:
            list[dict]: one descriptor per pattern, or ``[]`` for empty input.
        """
        print("Identifying consensus motifs...")

        if not significant_patterns:
            print("No significant patterns found for motif identification")
            return []

        def classify(structure, opens, closes):
            if structure.startswith('((') and structure.endswith('))'):
                return "Long Stem"
            if opens == 1 and closes == 1:
                return "Short Stem"
            if structure.count('.') > opens + closes:
                return "Loop-rich"
            return "Complex"

        motifs = []
        for info in significant_patterns[:top_n]:
            structure = info['pattern']
            size = len(structure)
            opens = structure.count('(')
            closes = structure.count(')')
            paired = opens + closes

            record = {
                'pattern': structure,
                'length': size,
                'stem_regions': opens,  # assumes symmetric brackets
                'paired_bases': paired,
                'unpaired_bases': structure.count('.'),
                'paired_ratio': paired / size if size > 0 else 0,
                'enrichment': info['enrichment'],
                'adjusted_p_value': info['adjusted_p_value'],
                'positive_count': info['positive_count'],
                'negative_count': info['negative_count']
            }
            record['motif_type'] = classify(structure, opens, closes)
            motifs.append(record)

        return motifs
    
    def create_pattern_visualization(self, pattern_analysis, characteristic_stats, output_file='pattern_analysis.png'):
        """创建模式分析可视化"""
        try:
            # 使用输出目录
            output_path = self.output_dir / output_file

            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
            fig.suptitle('Stem-loop Pattern Analysis Summary', fontsize=16, fontweight='bold')
            
            # 1. 富集模式散点图
            if pattern_analysis:
                top_patterns = pattern_analysis[:min(20, len(pattern_analysis))]
                enrichments = [min(p['enrichment'], 100) if p['enrichment'] != float('inf') else 100 for p in top_patterns]
                p_values = [-np.log10(p['adjusted_p_value'] + 1e-10) for p in top_patterns]
                pattern_lengths = [len(p['pattern']) for p in top_patterns]
                
                scatter = ax1.scatter(enrichments, p_values, c=pattern_lengths, 
                                    cmap='viridis', alpha=0.7, s=60)  # 修复：alpha是单个数值0.7
                ax1.set_xlabel('Enrichment (Positive/Negative)')
                ax1.set_ylabel('-log10(Adjusted p-value)')
                ax1.set_title('Stem-loop Pattern Enrichment')
                ax1.grid(True, alpha=0.3)
                plt.colorbar(scatter, ax=ax1, label='Pattern Length')
                
                # 添加显著性阈值线
                ax1.axhline(y=-np.log10(0.05), color='red', linestyle='--', alpha=0.7, label='p=0.05')
                ax1.axhline(y=-np.log10(0.01), color='darkred', linestyle='--', alpha=0.7, label='p=0.01')
                ax1.legend()
            else:
                ax1.text(0.5, 0.5, 'No pattern data available', 
                        ha='center', va='center', transform=ax1.transAxes)
                ax1.set_title('Stem-loop Pattern Enrichment')
            
            # 2. 特征差异条形图 - 这里是问题所在！
            if characteristic_stats:
                chars = list(characteristic_stats.keys())
                effect_sizes = [characteristic_stats[char]['cohens_d'] for char in chars]
                p_values = [characteristic_stats[char]['p_value'] for char in chars]
                
                # 根据p值设置颜色和透明度 - 修复alpha参数
                colors = []
                for p_val in p_values:
                    if p_val < 0.001:
                        colors.append('darkred')
                    elif p_val < 0.01:
                        colors.append('red')
                    elif p_val < 0.05:
                        colors.append('orange')
                    else:
                        colors.append('gray')
                
                # 修复：为每个条形设置单独的alpha值
                bars = ax2.barh(range(len(chars)), effect_sizes, color=colors, alpha=0.7)  # 统一alpha为0.7
                
                ax2.set_yticks(range(len(chars)))
                ax2.set_yticklabels(chars)
                ax2.set_xlabel("Cohen's d Effect Size")
                ax2.set_title('Stem Characteristic Differences\n(Color by significance)')
                ax2.grid(True, alpha=0.3)
                
                # 添加数值标签
                for i, bar in enumerate(bars):
                    width = bar.get_width()
                    ax2.text(width + 0.01 * (1 if width > 0 else -1), bar.get_y() + bar.get_height()/2,
                            f'{width:.2f}', ha='left' if width > 0 else 'right', va='center', fontsize=9)
            else:
                ax2.text(0.5, 0.5, 'No characteristic data available', 
                        ha='center', va='center', transform=ax2.transAxes)
                ax2.set_title('Stem Characteristic Differences')
            
            # 3. 模式频率分布
            if pattern_analysis:
                top_patterns = pattern_analysis[:min(10, len(pattern_analysis))]
                pos_counts = [p['positive_count'] for p in top_patterns]
                neg_counts = [p['negative_count'] for p in top_patterns]
                patterns_short = [p['pattern'][:10] + '...' if len(p['pattern']) > 10 else p['pattern'] for p in top_patterns]
                
                x = np.arange(len(patterns_short))
                width = 0.35
                
                ax3.bar(x - width/2, pos_counts, width, label='Positive', alpha=0.7, color='blue')  # 修复：alpha是单个数值
                ax3.bar(x + width/2, neg_counts, width, label='Negative', alpha=0.7, color='red')   # 修复：alpha是单个数值
                ax3.set_xlabel('Patterns')
                ax3.set_ylabel('Frequency')
                ax3.set_title('Top Pattern Frequency Distribution')
                ax3.set_xticks(x)
                ax3.set_xticklabels(patterns_short, rotation=45, ha='right', fontsize=8)
                ax3.legend()
                ax3.grid(True, alpha=0.3)
            else:
                ax3.text(0.5, 0.5, 'No pattern data available', 
                        ha='center', va='center', transform=ax3.transAxes)
                ax3.set_title('Pattern Frequency Distribution')
            
            # 4. 显著性总结
            if pattern_analysis:
                significant_count = sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.05)
                highly_significant = sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.01)
                very_highly_significant = sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.001)
                
                categories = ['Total', 'p<0.05', 'p<0.01', 'p<0.001']
                counts = [len(pattern_analysis), significant_count, highly_significant, very_highly_significant]
                
                colors = ['lightblue', 'lightgreen', 'orange', 'red']
                bars = ax4.bar(categories, counts, color=colors, alpha=0.7)  # 修复：alpha是单个数值
                ax4.set_ylabel('Number of Patterns')
                ax4.set_title('Statistical Significance Summary')
                ax4.grid(True, alpha=0.3)
                
                # 添加数值标签
                for i, (bar, count) in enumerate(zip(bars, counts)):
                    height = bar.get_height()
                    ax4.text(bar.get_x() + bar.get_width()/2, height + 0.1, 
                            str(count), ha='center', va='bottom')
            else:
                ax4.text(0.5, 0.5, 'No pattern data available', 
                        ha='center', va='center', transform=ax4.transAxes)
                ax4.set_title('Statistical Significance Summary')
            
            plt.tight_layout()
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close()
            
            print(f"Visualization saved to {output_path}")
        except Exception as e:
            print(f"Error creating visualization: {e}")
            import traceback
            traceback.print_exc()  # 添加详细错误信息
    
    def save_results(self, pattern_analysis, characteristic_stats, motifs, output_prefix='stem_patterns'):
        """保存分析结果"""
        # 使用输出目录
        output_prefix = str(self.output_dir / output_prefix)    
        
        # 保存模式分析结果
        if pattern_analysis:
            pattern_df = pd.DataFrame(pattern_analysis)
            pattern_df.to_csv(f'{output_prefix}_analysis.csv', index=False)
            print(f"- Pattern analysis: {output_prefix}_analysis.csv")
        else:
            print("- No pattern analysis data to save")
        
        # 保存特征统计
        if characteristic_stats:
            char_df = pd.DataFrame(characteristic_stats).T
            char_df.to_csv(f'{output_prefix}_characteristics.csv')
            print(f"- Characteristics: {output_prefix}_characteristics.csv")
        else:
            print("- No characteristic data to save")
        
        # 保存motif分析
        if motifs:
            motif_df = pd.DataFrame(motifs)
            # 过滤正富集的motif
            positive_enriched_motifs = motif_df[
                (motif_df['enrichment'] > 1) | (motif_df['enrichment'] == float('inf'))
            ]
            if not positive_enriched_motifs.empty:
                positive_enriched_motifs.to_csv(f'{output_prefix}_motifs.csv', index=False)
                print(f"- Motifs: {output_prefix}_motifs.csv")
            else:
                print("- No positively enriched motifs found")
        else:
            print("- No motif data to save")
        
        # 保存总结报告
        summary = {
            'total_patterns_analyzed': len(pattern_analysis) if pattern_analysis else 0,
            'significant_patterns_p05': sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.05) if pattern_analysis else 0,
            'significant_patterns_p01': sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.01) if pattern_analysis else 0,
            'significant_patterns_p001': sum(1 for p in pattern_analysis if p['adjusted_p_value'] < 0.001) if pattern_analysis else 0,
            'top_pattern': pattern_analysis[0]['pattern'] if pattern_analysis else None,
            'max_enrichment': max(p['enrichment'] for p in pattern_analysis) if pattern_analysis else 0,
            'positive_stems_count': len([p for p in pattern_analysis if p['positive_count'] > 0]) if pattern_analysis else 0,
            'negative_stems_count': len([p for p in pattern_analysis if p['negative_count'] > 0]) if pattern_analysis else 0,
            'analysis_timestamp': pd.Timestamp.now().isoformat()
        }
        
        with open(f'{output_prefix}_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)
        
        print(f"- Summary: {output_prefix}_summary.json")
        print(f"\nResults saved to '{self.output_dir}' directory")

def extract_stems_single_standalone(seq, struct, min_stem_length=1, max_stem_length=5):
    """Extract stem-loops for one sequence/structure pair and attach energies.

    Defined at module level (not as a method) so it can be used with
    multiprocessing. A throwaway ``StemPatternMiner`` configured with the
    given length bounds performs both extraction and energy calculation.

    Args:
        seq: the sequence string.
        struct: its structure string (presumably dot-bracket of the same
            length — confirm against ``extract_detailed_stem_loops``).
        min_stem_length: passed to the miner constructor and also directly
            to ``extract_detailed_stem_loops``.
        max_stem_length: passed to the miner constructor.

    Returns:
        The list of stem dicts returned by ``extract_detailed_stem_loops``,
        each mutated in place to carry an ``'energy'`` key. On any exception
        the error is printed and ``[]`` is returned (best effort, so a single
        bad record does not kill a worker).
    """
    try:
        worker_miner = StemPatternMiner(
            min_stem_length=min_stem_length,
            max_stem_length=max_stem_length,
        )
        extracted = worker_miner.extract_detailed_stem_loops(seq, struct, min_stem_length)

        # Annotate each stem with its energy, computed from its arms and loop
        for entry in extracted:
            left, right, loop = entry['left_arm'], entry['right_arm'], entry['loop_sequence']
            entry['energy'] = worker_miner.calculate_stem_energy(left, right, loop)
    except Exception as err:
        print(f"Error extracting stems: {err}")
        return []
    return extracted


def main(positive_file='positive_structures.txt', negative_file='negative_structures.txt'):
    """Stage 3 entry point: stem-loop pattern mining and statistical validation.

    Args:
        positive_file: path to the positive-set structures file. Defaults to
            the historical hard-coded name, resolved against the current
            working directory.
        negative_file: path to the negative-set structures file (same
            default convention).

    Errors are reported on stdout (with a traceback for unexpected
    exceptions) rather than raised; the function always returns None.
    """
    print("="*60)
    print("STAGE 3: STEM-LOOP PATTERN MINING AND STATISTICAL VALIDATION")
    print("="*60)

    # Fail fast, and say exactly which inputs are absent
    missing = [path for path in (positive_file, negative_file) if not os.path.exists(path)]
    if missing:
        print(f"Error: Input files '{positive_file}' or '{negative_file}' not found!")
        print("Missing: " + ", ".join(f"'{path}'" for path in missing))
        print("Please ensure these files exist (relative paths are resolved against the current directory).")
        return

    try:
        miner = StemPatternMiner()

        # run() performs the full pipeline. Its first returned element is the
        # pattern analysis (None on failure). Also treat a bare None return as
        # failure instead of crashing on tuple unpacking — NOTE(review): confirm
        # run()'s failure contract.
        result = miner.run(positive_file, negative_file)
        pattern_analysis = result[0] if result else None

        if pattern_analysis is None:
            print("Pattern mining failed!")
            return

        print("\nAnalysis completed successfully!")

    except Exception as e:
        # Top-level boundary: report with full traceback rather than crash.
        print(f"Error during analysis: {e}")
        import traceback
        traceback.print_exc()

# Run the stage-3 pipeline only when executed as a script, not on import
# (importing this module must stay side-effect free for worker processes).
if __name__ == "__main__":
    main()