import pandas as pd
import numpy as np
from collections import defaultdict
import re
import json
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Dict, List, Tuple
from pathlib import Path

class MotifSequenceMapper:
    def __init__(self, output_dir=".", max_motifs=5, similar_matching=True, similarity_threshold=0.9, max_length_diff=1):
        self.output_dir = Path(output_dir)
        self.sequence_motifs = defaultdict(list)
        self.motif_sequences = defaultdict(list)
        self.sequence_info = {}
        self.max_motifs = max_motifs
        # 设置相似性匹配参数
        self.similar_matching = similar_matching
        self.similarity_threshold = similarity_threshold
        self.max_length_diff = max_length_diff
        
    def run(self, pattern_analysis):
        """运行完整的motif映射流程 - 修复max_motifs限制问题"""
        print("Starting motif-sequence mapping analysis...")
        print(f"Similar matching: {self.similar_matching}")
        print(f"Similarity threshold: {self.similarity_threshold}")
        print(f"Max length difference: {self.max_length_diff}")
        
        try:
            # 选择显著的motif (p < 0.05)
            significant_motifs = [p for p in pattern_analysis if p['adjusted_p_value'] < 0.05]
            
            # 按显著性排序
            significant_motifs_sorted = sorted(significant_motifs, key=lambda x: x['adjusted_p_value'])
            
            # 关键修改：确保至少有max_motifs个motif用于分析
            if len(significant_motifs_sorted) >= self.max_motifs:
                # 如果显著motif足够，使用前max_motifs个
                selected_motifs = significant_motifs_sorted[:self.max_motifs]
                print(f"Using top {self.max_motifs} significant motifs (p < 0.05)")
            else:
                # 如果显著motif不足，补充其他motif
                selected_motifs = significant_motifs_sorted.copy()
                print(f"Only {len(significant_motifs_sorted)} significant motifs found (p < 0.05)")
                
                # 按富集度排序补充剩余的motif
                remaining_slots = self.max_motifs - len(selected_motifs)
                if remaining_slots > 0:
                    # 获取所有motif，排除已选中的
                    all_motifs_sorted = sorted(pattern_analysis, 
                                            key=lambda x: (x.get('enrichment', 0), -x.get('adjusted_p_value', 1)), 
                                            reverse=True)
                    
                    # 选择未选中的高富集度motif
                    for motif in all_motifs_sorted:
                        if motif not in selected_motifs and remaining_slots > 0:
                            selected_motifs.append(motif)
                            remaining_slots -= 1
                            print(f"  Adding motif by enrichment: {motif['pattern']} (enrichment={motif.get('enrichment', 0):.2f})")
                
                print(f"Final selection: {len(selected_motifs)} motifs ({len(significant_motifs_sorted)} significant + {self.max_motifs - len(significant_motifs_sorted)} by enrichment)")
            
            # 验证Step3的计数
            total_step3_positive = sum(m.get('positive_count', 0) for m in selected_motifs)
            total_step3_negative = sum(m.get('negative_count', 0) for m in selected_motifs)
            print(f"Step3 total counts - Positive stems: {total_step3_positive}, Negative stems: {total_step3_negative}")
            
            # 首先尝试加载Step3的stem数据
            print("Attempting to load Step3 stem data...")
            positive_stems, negative_stems = self.load_step3_stem_data()
            
            if not positive_stems or not negative_stems:
                print("Warning: Could not load Step3 stem data, using direct sequence analysis...")
                return self.run_direct_sequence_analysis(selected_motifs)  # 传入selected_motifs
            
            print(f"Loaded {len(positive_stems)} positive stems, {len(negative_stems)} negative stems from Step3")
            
            # 映射motif到Step3的stem数据 - 传入selected_motifs
            all_mappings = self.map_motifs_to_step3_stems_with_flexible_matching(
                selected_motifs, positive_stems, negative_stems
            )
            
            print(f"Step4 mapped {len(all_mappings)} total stem instances")
            
            # 验证计数一致性
            actual_positive_mappings = len([m for m in all_mappings if m['label'] == 'positive'])
            actual_negative_mappings = len([m for m in all_mappings if m['label'] == 'negative'])
            
            print(f"Step4 instance counts - Positive: {actual_positive_mappings}, Negative: {actual_negative_mappings}")
            
            if total_step3_positive > 0:
                mapping_ratio = actual_positive_mappings / total_step3_positive
                print(f"Mapping ratio (Step4 instances / Step3 stems): {mapping_ratio:.2f}")
                if mapping_ratio < 0.95:
                    print("Warning: Low mapping ratio, some stems may not be matched correctly")
            
            # 创建报告和可视化
            mapping_df, motif_stats_df, sequence_summary_df = self.create_detailed_reports(
                all_mappings, selected_motifs  # 传入selected_motifs
            )
            
            # 创建可视化
            self.create_enhanced_visualizations(mapping_df, motif_stats_df, selected_motifs)  # 传入selected_motifs
            
            print(f"\nMotif mapping completed successfully!")
            return True
            
        except Exception as e:
            print(f"Error during motif mapping: {str(e)}")
            import traceback
            traceback.print_exc()
            return False

    def run_direct_sequence_analysis(self, pattern_analysis):
        """当无法加载Step3数据时，直接从序列文件分析 - 修复max_motifs限制"""
        print("Running direct sequence analysis...")
        
        try:
            # 选择显著的motif (p < 0.05)
            significant_motifs = [p for p in pattern_analysis if p['adjusted_p_value'] < 0.05]
            significant_motifs_sorted = sorted(significant_motifs, key=lambda x: x['adjusted_p_value'])
            
            # 关键修改：确保至少有max_motifs个motif用于分析
            if len(significant_motifs_sorted) >= self.max_motifs:
                selected_motifs = significant_motifs_sorted[:self.max_motifs]
                print(f"Using top {self.max_motifs} significant motifs")
            else:
                selected_motifs = significant_motifs_sorted.copy()
                print(f"Only {len(significant_motifs_sorted)} significant motifs found")
                
                # 按富集度排序补充剩余的motif
                remaining_slots = self.max_motifs - len(selected_motifs)
                if remaining_slots > 0:
                    all_motifs_sorted = sorted(pattern_analysis, 
                                            key=lambda x: (x.get('enrichment', 0), -x.get('adjusted_p_value', 1)), 
                                            reverse=True)
                    
                    for motif in all_motifs_sorted:
                        if motif not in selected_motifs and remaining_slots > 0:
                            selected_motifs.append(motif)
                            remaining_slots -= 1
                            print(f"  Adding motif by enrichment: {motif['pattern']} (enrichment={motif.get('enrichment', 0):.2f})")
                
                print(f"Final selection: {len(selected_motifs)} motifs")
            
            print(f"Processing {len(selected_motifs)} motifs for direct sequence analysis")

            # 直接解析序列文件
            positive_file = self.output_dir / 'positive_structures.txt'
            negative_file = self.output_dir / 'negative_structures.txt'
            
            if not positive_file.exists() or not negative_file.exists():
                print("Error: Sequence files not found!")
                return False
            
            positive_sequences, positive_structures = self.parse_sequences_with_names(str(positive_file))
            negative_sequences, negative_structures = self.parse_sequences_with_names(str(negative_file))
            
            print(f"Loaded {len(positive_sequences)} positive sequences, {len(negative_sequences)} negative sequences")
            
            # 直接提取stem并映射 - 使用selected_motifs
            all_mappings = []
            for motif_info in selected_motifs:  # 改为使用selected_motifs
                motif_pattern = motif_info['pattern']
                print(f"Processing motif: {motif_pattern}")
                
                # 在阳性序列中查找
                for seq_name, structure in positive_structures.items():
                    if seq_name in positive_sequences:
                        stems = self.extract_stem_loops_compatible(positive_sequences[seq_name], structure, seq_name)
                        for stem in stems:
                            if self.is_stem_match_strict(stem['stem_structure'], motif_pattern):
                                mapping = stem.copy()
                                mapping.update({
                                    'motif_pattern': motif_pattern,
                                    'matched_pattern': stem['stem_structure'],
                                    'match_type': 'exact' if stem['stem_structure'] == motif_pattern else 'similar',
                                    'label': 'positive',
                                    'motif_enrichment': motif_info.get('enrichment', 0),
                                    'motif_p_value': motif_info.get('adjusted_p_value', 1.0)
                                })
                                all_mappings.append(mapping)
                
                # 在阴性序列中查找
                for seq_name, structure in negative_structures.items():
                    if seq_name in negative_sequences:
                        stems = self.extract_stem_loops_compatible(negative_sequences[seq_name], structure, seq_name)
                        for stem in stems:
                            if self.is_stem_match_strict(stem['stem_structure'], motif_pattern):
                                mapping = stem.copy()
                                mapping.update({
                                    'motif_pattern': motif_pattern,
                                    'matched_pattern': stem['stem_structure'],
                                    'match_type': 'exact' if stem['stem_structure'] == motif_pattern else 'similar',
                                    'label': 'negative',
                                    'motif_enrichment': motif_info.get('enrichment', 0),
                                    'motif_p_value': motif_info.get('adjusted_p_value', 1.0)
                                })
                                all_mappings.append(mapping)
            
            # 创建报告
            mapping_df, motif_stats_df, sequence_summary_df = self.create_detailed_reports(
                all_mappings, selected_motifs  # 传入selected_motifs
            )
            
            # 创建可视化
            self.create_enhanced_visualizations(mapping_df, motif_stats_df, selected_motifs)  # 传入selected_motifs
            
            print(f"Direct sequence analysis completed with {len(all_mappings)} mappings!")
            return True
            
        except Exception as e:
            print(f"Error in direct sequence analysis: {str(e)}")
            return False

    def load_step3_stem_data(self):
        """加载Step3生成的stem数据 - 修复版本"""
        try:
            # 查找所有可能的pickle文件
            pickle_files = list(self.output_dir.glob('*_stems.pkl'))
            print(f"Looking for stem files in: {self.output_dir}")
            # print(f"Found pickle files: {[f.name for f in pickle_files]}")
            
            positive_stems = []
            negative_stems = []
            
            for pkl_file in pickle_files:
                try:
                    import joblib
                    if 'positive' in pkl_file.name.lower():
                        print(f"Loading positive stems from: {pkl_file}")
                        positive_stems = joblib.load(pkl_file)
                        print(f"Loaded {len(positive_stems)} positive stems")
                    elif 'negative' in pkl_file.name.lower():
                        print(f"Loading negative stems from: {pkl_file}")
                        negative_stems = joblib.load(pkl_file)
                        print(f"Loaded {len(negative_stems)} negative stems")
                except Exception as e:
                    print(f"Error loading {pkl_file}: {e}")
            
            if positive_stems and negative_stems:
                print(f"Successfully loaded {len(positive_stems)} positive stems, {len(negative_stems)} negative stems")
                return positive_stems, negative_stems
            else:
                print("Could not find both positive and negative stem files, trying CSV reconstruction...")
                return self.reconstruct_stems_from_analysis()
                
        except Exception as e:
            print(f"Error loading Step3 stem data: {e}")
            return self.reconstruct_stems_from_analysis()
    
    def load_stems_from_csv(self):
        """从CSV文件加载stem数据"""
        try:
            positive_stems = []
            negative_stems = []
            
            # 尝试加载Step3的分析结果
            analysis_file = self.output_dir / 'stem_patterns_analysis.csv'
            if analysis_file.exists():
                print("Reconstructing stems from analysis data...")
                return self.reconstruct_stems_from_analysis()
            
            print("No Step3 stem data found. Please ensure Step3 ran successfully.")
            return [], []
            
        except Exception as e:
            print(f"Error loading stems from CSV: {e}")
            return [], []

    def reconstruct_stems_from_analysis(self):
        """从Step3的分析结果重建stem数据 - 增强版本"""
        try:
            # 加载序列文件
            positive_file = self.output_dir / 'positive_structures.txt'
            negative_file = self.output_dir / 'negative_structures.txt'
            
            positive_sequences, positive_structures = self.parse_sequences_with_names(str(positive_file))
            negative_sequences, negative_structures = self.parse_sequences_with_names(str(negative_file))
            
            print(f"Loaded {len(positive_sequences)} positive sequences, {len(negative_sequences)} negative sequences")
            
            # 使用与Step3相同的参数提取stem
            print("Reconstructing stems using Step3-compatible extraction...")
            positive_stems = []
            negative_stems = []
            
            # 提取阳性stem
            for seq_name, structure in positive_structures.items():
                if seq_name in positive_sequences:
                    stems = self.extract_stem_loops_compatible(positive_sequences[seq_name], structure, seq_name)
                    positive_stems.extend(stems)
            
            # 提取阴性stem  
            for seq_name, structure in negative_structures.items():
                if seq_name in negative_sequences:
                    stems = self.extract_stem_loops_compatible(negative_sequences[seq_name], structure, seq_name)
                    negative_stems.extend(stems)
            
            print(f"Reconstructed {len(positive_stems)} positive stems, {len(negative_stems)} negative stems")
            return positive_stems, negative_stems
            
        except Exception as e:
            print(f"Error reconstructing stems: {e}")
            return [], []

    def extract_stem_loops_compatible(self, sequence: str, structure: str, sequence_name: str) -> List[Dict]:
        """使用与Step3兼容的stem提取逻辑 - 修复版本"""
        stems = []
        stack = []
        
        try:
            for i, char in enumerate(structure):
                if char == '(':
                    stack.append(i)
                elif char == ')':
                    if stack:
                        start = stack.pop()
                        # 计算stem长度 - 使用与Step3相同的逻辑
                        stem_length = 0
                        j = 0
                        while (start + j < len(structure) and 
                            i - j >= 0 and 
                            structure[start + j] == '(' and 
                            structure[i - j] == ')'):
                            stem_length += 1
                            j += 1
                        
                        # 使用与Step3相同的最小stem长度
                        if stem_length >= 1:  # 与Step3的min_stem_length=1保持一致
                            stem_struct = structure[start:i+1]
                            
                            # 提取序列信息
                            left_arm = sequence[start:start+stem_length]
                            right_arm_start = i - stem_length + 1
                            right_arm = sequence[right_arm_start:i+1] if right_arm_start >= 0 else ""
                            
                            stems.append({
                                'sequence_name': sequence_name,
                                'stem_structure': stem_struct,
                                'start_position': start,
                                'end_position': i,
                                'stem_length': stem_length,
                                'left_arm': left_arm,
                                'right_arm': right_arm,
                                'sequence_region': sequence[start:i+1],
                                'full_sequence': sequence,
                                'full_structure': structure
                            })
            return stems
        except Exception as e:
            print(f"Error extracting stems from {sequence_name}: {str(e)}")
            return []

    def map_motifs_to_step3_stems_with_flexible_matching(self, significant_motifs: List[Dict], 
                                                        positive_stems: List[Dict], 
                                                        negative_stems: List[Dict]) -> List[Dict]:
        """将motif映射到Step3的stem数据 - 使用更严格的匹配"""
        print("Mapping motifs to Step3 stem data with STRICT matching...")
        
        # 确保significant_motifs应用了max_motifs限制
        if len(significant_motifs) >= self.max_motifs:
            print(f"Limiting motifs to top {self.max_motifs} in mapping function")
            significant_motifs = significant_motifs[:self.max_motifs]
        elif len(significant_motifs) < self.max_motifs:
            print(f"Note: Only {len(significant_motifs)} motifs available (less than max_motifs={self.max_motifs})")
            # 如果significant_motifs数量不足，保持原样使用所有可用的motifs
            # 这里不需要额外操作，因为significant_motifs已经包含了所有可用的motifs
        
        all_mappings = []
        
        for i, motif_info in enumerate(significant_motifs):
            motif_pattern = motif_info['pattern']
            print(f"Processing motif {i+1}/{len(significant_motifs)}: {motif_pattern}")
            
            try:
                # 在阳性stem中查找匹配 - 使用更严格的匹配
                positive_matches = []
                for stem in positive_stems:
                    if self.is_stem_match_strict(stem['stem_structure'], motif_pattern):
                        match = stem.copy()
                        match.update({
                            'motif_pattern': motif_pattern,
                            'matched_pattern': stem['stem_structure'],
                            'match_type': 'exact' if stem['stem_structure'] == motif_pattern else 'similar',
                            'label': 'positive'
                        })
                        positive_matches.append(match)
                
                # 在阴性stem中查找匹配
                negative_matches = []
                for stem in negative_stems:
                    if self.is_stem_match_strict(stem['stem_structure'], motif_pattern):
                        match = stem.copy()
                        match.update({
                            'motif_pattern': motif_pattern,
                            'matched_pattern': stem['stem_structure'],
                            'match_type': 'exact' if stem['stem_structure'] == motif_pattern else 'similar', 
                            'label': 'negative'
                        })
                        negative_matches.append(match)
                
                # 限制每个序列中每个motif的最大匹配数
                positive_matches = self.limit_matches_per_sequence(positive_matches, max_per_sequence=1)
                negative_matches = self.limit_matches_per_sequence(negative_matches, max_per_sequence=1)
                
                # 统计信息
                step3_positive_count = motif_info.get('positive_count', 0)
                step3_negative_count = motif_info.get('negative_count', 0)
                
                exact_positive = len([m for m in positive_matches if m['match_type'] == 'exact'])
                similar_positive = len([m for m in positive_matches if m['match_type'] == 'similar'])
                
                print(f"  Step3 count: {step3_positive_count}, Step4 matches: {len(positive_matches)} (exact: {exact_positive}, similar: {similar_positive})")
                
                # 添加到映射结果
                for match in positive_matches:
                    match.update({
                        'motif_enrichment': motif_info.get('enrichment', 0),
                        'motif_p_value': motif_info.get('adjusted_p_value', 1.0),
                        'step3_positive_count': step3_positive_count,
                        'step3_negative_count': step3_negative_count
                    })
                    all_mappings.append(match)
                
                for match in negative_matches:
                    match.update({
                        'motif_enrichment': motif_info.get('enrichment', 0),
                        'motif_p_value': motif_info.get('adjusted_p_value', 1.0),
                        'step3_positive_count': step3_positive_count,
                        'step3_negative_count': step3_negative_count
                    })
                    all_mappings.append(match)
                
                # 存储motif统计
                self.motif_sequences[motif_pattern] = {
                    'stats': {
                        'step3_positive_count': step3_positive_count,
                        'step3_negative_count': step3_negative_count,
                        'step4_positive_matches': len(positive_matches),
                        'step4_negative_matches': len(negative_matches),
                        'exact_positive_matches': exact_positive,
                        'similar_positive_matches': similar_positive
                    },
                    'positive_matches': positive_matches,
                    'negative_matches': negative_matches
                }
                
            except Exception as e:
                print(f"Error processing motif {motif_pattern}: {str(e)}")
                continue
        
        return all_mappings


    def is_stem_match_strict(self, stem_structure: str, motif_pattern: str) -> bool:
        """判断stem结构是否与motif匹配 - 使用更严格的标准"""
        # 1. 完全匹配 - 最高优先级
        if stem_structure == motif_pattern:
            return True
        
        # 如果关闭相似性匹配，只进行完全匹配
        if not self.similar_matching:
            return False
        
        # 2. 长度差异不能太大（使用参数控制）
        len_diff = abs(len(stem_structure) - len(motif_pattern))
        if len_diff > self.max_length_diff:
            return False
        
        # 3. 结构相似性必须很高（使用参数控制的阈值）
        if self.structure_similarity_strict(stem_structure, motif_pattern) >= self.similarity_threshold:
            return True
        
        # 4. 配对比例必须相似（差异小于0.1）
        paired_ratio1 = (stem_structure.count('(') + stem_structure.count(')')) / len(stem_structure)
        paired_ratio2 = (motif_pattern.count('(') + motif_pattern.count(')')) / len(motif_pattern)
        if abs(paired_ratio1 - paired_ratio2) > 0.1:
            return False
        
        # 5. 必须是motif的子串（包含关系），且长度差异小
        if (motif_pattern in stem_structure and len_diff <= self.max_length_diff) or \
        (stem_structure in motif_pattern and len_diff <= self.max_length_diff):
            return True
        
        return False

    def structure_similarity_strict(self, pattern1: str, pattern2: str) -> float:
        """计算两个结构模式的相似性得分 - 更严格版本"""
        try:
            import Levenshtein
            # 计算编辑距离相似性 - 使用更严格的权重
            edit_sim = 1 - (Levenshtein.distance(pattern1, pattern2) / max(len(pattern1), len(pattern2)))
            
            # 计算结构特征相似性
            struct_sim = self.structural_feature_similarity_strict(pattern1, pattern2)
            
            # 加权综合相似性 - 更注重结构特征
            combined_sim = 0.4 * edit_sim + 0.6 * struct_sim
            
            return combined_sim
        except ImportError:
            # 简单的相似性计算 - 更严格
            if len(pattern1) == 0 or len(pattern2) == 0:
                return 0.0
            
            # 计算基本特征相似性
            len_sim = 1 - abs(len(pattern1) - len(pattern2)) / max(len(pattern1), len(pattern2))
            paired_ratio1 = (pattern1.count('(') + pattern1.count(')')) / len(pattern1)
            paired_ratio2 = (pattern2.count('(') + pattern2.count(')')) / len(pattern2)
            ratio_sim = 1 - abs(paired_ratio1 - paired_ratio2)
            
            # 要求两者都很高
            return min(len_sim, ratio_sim)

    def structural_feature_similarity_strict(self, pattern1: str, pattern2: str) -> float:
        """基于结构特征的相似性计算 - 更严格版本"""
        features1 = self.extract_structural_features_strict(pattern1)
        features2 = self.extract_structural_features_strict(pattern2)
        
        if not features1 or not features2:
            return 0.0
        
        # 计算特征向量相似性 - 使用余弦相似性
        common_features = set(features1.keys()) | set(features2.keys())
        dot_product = sum(features1.get(f, 0) * features2.get(f, 0) for f in common_features)
        norm1 = np.sqrt(sum(v**2 for v in features1.values()))
        norm2 = np.sqrt(sum(v**2 for v in features2.values()))
        
        if norm1 == 0 or norm2 == 0:
            return 0.0
        
        similarity = dot_product / (norm1 * norm2)
        
        # 应用更严格的标准：相似性必须很高
        return similarity if similarity >= 0.8 else 0.0

    def extract_structural_features_strict(self, pattern: str) -> Dict:
        """提取结构特征向量 - 更详细的特征"""
        features = {}
        
        if len(pattern) == 0:
            return features
            
        # 基本结构特征
        features['length'] = len(pattern)
        features['paired_count'] = pattern.count('(') + pattern.count(')')
        features['unpaired_count'] = pattern.count('.')
        features['paired_ratio'] = features['paired_count'] / len(pattern)
        
        # 详细的结构特征
        features['left_paired'] = pattern.count('(')
        features['right_paired'] = pattern.count(')')
        features['balance'] = 1 - abs(features['left_paired'] - features['right_paired']) / max(features['left_paired'], features['right_paired']) if max(features['left_paired'], features['right_paired']) > 0 else 1.0
        
        # 连续性特征
        features['max_stem_run'] = self.max_consecutive_count(pattern, '(')
        features['max_loop_run'] = self.max_consecutive_count(pattern, '.')
        features['stem_continuity'] = features['max_stem_run'] / features['left_paired'] if features['left_paired'] > 0 else 0
        
        return features
    
    def limit_matches_per_sequence(self, matches: List[Dict], max_per_sequence: int = 10) -> List[Dict]:
        """限制每个序列中每个motif的最大匹配数为10个"""
        sequence_motif_count = defaultdict(int)
        filtered_matches = []
        
        for match in matches:
            seq_name = match['sequence_name']
            motif_pattern = match['motif_pattern']
            key = (seq_name, motif_pattern)
            
            if sequence_motif_count[key] < max_per_sequence:
                sequence_motif_count[key] += 1
                filtered_matches.append(match)
            # 否则跳过这个匹配
        
        print(f"  Limited matches: {len(matches)} -> {len(filtered_matches)} (max {max_per_sequence} per sequence per motif)")
        return filtered_matches

    def max_consecutive_count(self, pattern: str, char: str) -> int:
        """计算字符连续出现的最大次数"""
        max_count = 0
        current_count = 0
        
        for c in pattern:
            if c == char:
                current_count += 1
                max_count = max(max_count, current_count)
            else:
                current_count = 0
        
        return max_count

    def parse_sequences_with_names(self, filename: str) -> Tuple[Dict, Dict]:
        """解析带序列名的文件"""
        sequences = {}
        structures = {}
        current_name = ""
        
        try:
            with open(filename, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line.startswith('>'):
                        current_name = line[1:]  # 去掉'>'
                        sequences[current_name] = ""
                        structures[current_name] = ""
                    elif line and not line.startswith('>'):
                        if all(c in 'ACGUacgu.-' for c in line):
                            sequences[current_name] = line.upper().replace('-', '').replace('.', '')
                        elif any(c in '().' for c in line):
                            structures[current_name] = line
            
            # 清理空条目
            sequences = {k: v for k, v in sequences.items() if v}
            structures = {k: v for k, v in structures.items() if v}
            
            return sequences, structures
            
        except FileNotFoundError:
            print(f"Error: File {filename} not found.")
            return {}, {}
        except Exception as e:
            print(f"Error parsing {filename}: {str(e)}")
            return {}, {}

    def create_detailed_reports(self, all_mappings: List[Dict], significant_motifs: List[Dict], 
                            output_prefix: str = 'motif_mapping') -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """创建详细报告 - 增强版本"""
        print("Creating detailed reports...")
        
        # 使用输出目录
        output_prefix = str(self.output_dir / output_prefix)
        
        try:
            # 1. 主要映射结果 - 包含所有stem实例
            if all_mappings:
                # 为每个映射添加唯一标识符
                for i, mapping in enumerate(all_mappings):
                    mapping['unique_instance_id'] = f"{mapping['sequence_name']}_{mapping['start_position']}_{mapping['end_position']}_{i}"
                
                mapping_df = pd.DataFrame(all_mappings)
                
                # 重新排列列的顺序，让重要信息在前面
                base_columns = ['sequence_name', 'motif_pattern', 'matched_pattern', 'match_type', 'label', 
                              'start_position', 'end_position', 'stem_length', 'unique_instance_id']
                
                # 添加其他列
                other_columns = [col for col in mapping_df.columns if col not in base_columns]
                new_column_order = base_columns + other_columns
                
                mapping_df = mapping_df.reindex(columns=new_column_order)
                mapping_df.to_csv(f'{output_prefix}_detailed.csv', index=False)
                
                # 输出统计信息
                total_instances = len(mapping_df)
                positive_instances = len(mapping_df[mapping_df['label'] == 'positive'])
                negative_instances = len(mapping_df[mapping_df['label'] == 'negative'])
                exact_matches = len(mapping_df[mapping_df['match_type'] == 'exact'])
                similar_matches = len(mapping_df[mapping_df['match_type'] == 'similar'])
                
                print(f"Detailed mapping: {positive_instances} positive, {negative_instances} negative instances")
                print(f"Match types: {exact_matches} exact, {similar_matches} similar matches")
                
                # 检查序列重复情况
                seq_motif_counts = mapping_df.groupby(['sequence_name', 'motif_pattern']).size()
                sequences_with_multiple = seq_motif_counts[seq_motif_counts > 1]
                if len(sequences_with_multiple) > 0:
                    print(f"  {len(sequences_with_multiple)} sequence-motif pairs have multiple stem instances")
            else:
                mapping_df = pd.DataFrame()
                print("Warning: No mappings found to save.")
            
            # 2. Motif统计总结 - 对比Step3和Step4的计数
            motif_stats = []
            for motif_pattern, data in self.motif_sequences.items():
                stats = data['stats'].copy()
                stats['motif_pattern'] = motif_pattern
                
                # 计算计数一致性
                step3_pos = stats.get('step3_positive_count', 0)
                step4_pos = stats.get('step4_positive_matches', 0)
                if step3_pos > 0:
                    stats['positive_count_ratio'] = step4_pos / step3_pos
                else:
                    stats['positive_count_ratio'] = 0
                    
                step3_neg = stats.get('step3_negative_count', 0)
                step4_neg = stats.get('step4_negative_matches', 0)
                if step3_neg > 0:
                    stats['negative_count_ratio'] = step4_neg / step3_neg
                else:
                    stats['negative_count_ratio'] = 0
                
                motif_stats.append(stats)
            
            motif_stats_df = pd.DataFrame(motif_stats)
            if not motif_stats_df.empty:
                # 重新排列列的顺序
                stat_columns = ['motif_pattern', 'step3_positive_count', 'step4_positive_matches', 
                              'exact_positive_matches', 'similar_positive_matches', 'positive_count_ratio',
                              'step3_negative_count', 'step4_negative_matches', 'negative_count_ratio']
                
                existing_columns = [col for col in stat_columns if col in motif_stats_df.columns]
                other_columns = [col for col in motif_stats_df.columns if col not in existing_columns]
                
                motif_stats_df = motif_stats_df.reindex(columns=existing_columns + other_columns)
                motif_stats_df.to_csv(f'{output_prefix}_statistics.csv', index=False)
                
                # 输出计数一致性统计
                avg_positive_ratio = motif_stats_df['positive_count_ratio'].mean()
                avg_negative_ratio = motif_stats_df['negative_count_ratio'].mean()
                print(f"Count consistency - Positive: {avg_positive_ratio:.2f}, Negative: {avg_negative_ratio:.2f}")
            else:
                print("Warning: No motif statistics to save.")
            
            # 3. 每个序列的motif汇总
            sequence_summary = []
            if not mapping_df.empty:
                for seq_name in set(mapping_df['sequence_name']):
                    seq_data = mapping_df[mapping_df['sequence_name'] == seq_name]
                    label = seq_data['label'].iloc[0] if len(seq_data) > 0 else 'unknown'
                    
                    # 统计每个motif在该序列中的出现次数
                    motif_counts = seq_data['motif_pattern'].value_counts()
                    total_stems = len(seq_data)
                    unique_motifs = len(motif_counts)
                    
                    sequence_summary.append({
                        'sequence_name': seq_name,
                        'label': label,
                        'total_stem_instances': total_stems,
                        'unique_motifs': unique_motifs,
                        'motifs_with_counts': '; '.join([f"{motif}({count})" for motif, count in motif_counts.items()])
                    })
                
                sequence_summary_df = pd.DataFrame(sequence_summary)
                sequence_summary_df.to_csv(f'{output_prefix}_sequence_summary.csv', index=False)
                
                # 输出序列统计
                avg_stems_per_seq = sequence_summary_df['total_stem_instances'].mean()
                max_stems_per_seq = sequence_summary_df['total_stem_instances'].max()
                print(f"Sequence statistics: {avg_stems_per_seq:.1f} avg stems per sequence, max: {max_stems_per_seq}")
            else:
                sequence_summary_df = pd.DataFrame()
            
            return mapping_df, motif_stats_df, sequence_summary_df
            
        except Exception as e:
            print(f"Error creating reports: {str(e)}")
            return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    def create_enhanced_visualizations(self, mapping_df: pd.DataFrame, motif_stats_df: pd.DataFrame,
                                       significant_motifs: List[Dict]):
        """Render a 2x2 overview figure and save it as ``motif_mapping_analysis.png``.

        Panels:
            1. Step3 count vs. Step4 matches (positive stems) for the first
               ten rows of ``motif_stats_df``.
            2. Pie chart of ``match_type`` values in ``mapping_df``.
            3. Histogram of stem instances per sequence, split by label.
            4. Histogram of ``positive_count_ratio`` across motifs.

        A panel whose input frame is empty shows a placeholder text instead.

        Args:
            mapping_df: One row per stem instance; the columns
                ``sequence_name``, ``label`` and ``match_type`` are read.
            motif_stats_df: One row per motif; the columns ``motif_pattern``,
                ``step3_positive_count``, ``step4_positive_matches`` and
                (optionally) ``positive_count_ratio`` are read.
            significant_motifs: Accepted for interface compatibility; this
                method does not read it.

        Returns:
            None. Any failure is printed and swallowed.
        """
        print("Creating enhanced visualizations...")

        def show_placeholder(ax, message, title):
            # Centered note for a panel that has nothing to plot.
            ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
            ax.set_title(title)

        try:
            # savefig does not create missing directories.
            self.output_dir.mkdir(parents=True, exist_ok=True)
            output_file = self.output_dir / 'motif_mapping_analysis.png'

            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
            try:
                fig.suptitle('Motif-Sequence Mapping Analysis', fontsize=16, fontweight='bold')

                # 1. Step3 vs Step4 count comparison (bar chart).
                if not motif_stats_df.empty:
                    top_motifs = motif_stats_df.head(10)
                    motifs_short = [self._create_motif_abbreviation(m) for m in top_motifs['motif_pattern']]

                    x = np.arange(len(motifs_short))
                    width = 0.35

                    step3_pos = top_motifs['step3_positive_count']
                    step4_pos = top_motifs['step4_positive_matches']

                    ax1.bar(x - width/2, step3_pos, width, label='Step3 Count', alpha=0.7, color='lightblue')
                    ax1.bar(x + width/2, step4_pos, width, label='Step4 Matches', alpha=0.7, color='blue')
                    ax1.set_xlabel('Motifs')
                    ax1.set_ylabel('Count')
                    ax1.set_title('Step3 vs Step4 Count Comparison\n(Positive Stems)')
                    ax1.set_xticks(x)
                    ax1.set_xticklabels(motifs_short, rotation=45, ha='right')
                    ax1.legend()
                    ax1.grid(True, alpha=0.3)

                    # Value labels above each bar; a missing count is skipped
                    # because int(NaN) would abort the whole figure.
                    for offset, counts in ((-width/2, step3_pos), (width/2, step4_pos)):
                        for i, value in enumerate(counts):
                            if pd.notna(value):
                                ax1.text(i + offset, value + 0.1, str(int(value)),
                                         ha='center', va='bottom', fontsize=8)
                else:
                    show_placeholder(ax1, 'No motif statistics available', 'Step3 vs Step4 Count Comparison')

                # 2. Match type distribution (pie chart).
                if not mapping_df.empty:
                    match_type_counts = mapping_df['match_type'].value_counts()
                    # Colors are keyed by match type so that a given type keeps
                    # its color regardless of which one is more frequent.
                    match_type_colors = {'exact': 'green', 'similar': 'orange'}
                    colors = [match_type_colors.get(kind, 'gray') for kind in match_type_counts.index]

                    ax2.pie(match_type_counts.values, labels=match_type_counts.index,
                            autopct='%1.1f%%', colors=colors, startangle=90)
                    ax2.set_title('Match Type Distribution')
                else:
                    show_placeholder(ax2, 'No mapping data available', 'Match Type Distribution')

                # 3. Stem instances per sequence, split by label.
                if not mapping_df.empty:
                    # One value_counts pass per label instead of re-filtering
                    # the frame for every sequence.
                    positive_stem_counts = mapping_df.loc[
                        mapping_df['label'] == 'positive', 'sequence_name'].value_counts().tolist()
                    negative_stem_counts = mapping_df.loc[
                        mapping_df['label'] == 'negative', 'sequence_name'].value_counts().tolist()

                    if positive_stem_counts:
                        ax3.hist(positive_stem_counts, alpha=0.7, label='Positive',
                                 bins=10, color='red', density=True)
                    if negative_stem_counts:
                        ax3.hist(negative_stem_counts, alpha=0.7, label='Negative',
                                 bins=10, color='blue', density=True)

                    ax3.set_xlabel('Number of Stem Instances per Sequence')
                    ax3.set_ylabel('Density')
                    ax3.set_title('Distribution of Stem Instances\nper Sequence')
                    # Only request a legend when a histogram was actually drawn.
                    if positive_stem_counts or negative_stem_counts:
                        ax3.legend()
                    ax3.grid(True, alpha=0.3)
                else:
                    show_placeholder(ax3, 'No mapping data available', 'Stem Instance Distribution')

                # 4. Distribution of the Step4/Step3 mapping ratio.
                if not motif_stats_df.empty and 'positive_count_ratio' in motif_stats_df.columns:
                    ratios = motif_stats_df['positive_count_ratio']
                    mean_ratio = ratios.mean()
                    ax4.hist(ratios, bins=15, alpha=0.7, color='purple', edgecolor='black')
                    ax4.axvline(mean_ratio, color='red', linestyle='--', label=f'Mean: {mean_ratio:.2f}')
                    ax4.set_xlabel('Mapping Ratio (Step4/Step3)')
                    ax4.set_ylabel('Frequency')
                    ax4.set_title('Distribution of Mapping Ratios\nAcross Motifs')
                    ax4.legend()
                    ax4.grid(True, alpha=0.3)
                else:
                    show_placeholder(ax4, 'No ratio data available', 'Mapping Ratio Distribution')

                fig.tight_layout()
                fig.savefig(output_file, dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even when a panel or the save step fails;
                # otherwise it stays registered with pyplot.
                plt.close(fig)

            print(f"Visualization saved to {output_file}")

        except Exception as e:
            # Best-effort plotting: a failure is reported but never raised.
            print(f"Error creating visualizations: {str(e)}")

    def _create_motif_abbreviation(self, motif: str, max_length: int = 15) -> str:
        """创建motif的缩写名称用于显示"""
        if len(motif) <= max_length:
            return motif
        return motif[:max_length-3] + "..."

    def run_fallback(self, pattern_analysis):
        """备用方法 - 当无法加载Step3数据时使用"""
        print("Using fallback method with sequence extraction...")
        # 这里可以实现在没有Step3数据时的备选方案
        return False

def main():
    """Script entry point for Step 4.

    Prints a banner, loads ``stem_patterns_analysis.csv`` from the current
    working directory and passes its rows to ``MotifSequenceMapper.run``.
    A missing file is reported and the function returns early; any other
    error is printed together with its traceback.
    """
    banner = "=" * 60
    print(banner)
    print("STEP 4: MOTIF-SEQUENCE MAPPING ANALYSIS")
    print(banner)

    # Mapper with default settings.
    mapper = MotifSequenceMapper()

    patterns_csv = 'stem_patterns_analysis.csv'

    try:
        # The stage 3 output is a hard requirement.
        if os.path.exists(patterns_csv):
            patterns = pd.read_csv(patterns_csv)
            print(f"Loaded {len(patterns)} patterns from stage 3")

            # run() drives the complete mapping workflow.
            succeeded = mapper.run(patterns.to_dict('records'))

            message = "\nMotif mapping completed successfully!" if succeeded else "\nMotif mapping failed!"
            print(message)
        else:
            print("Error: stem_patterns_analysis.csv not found. Please run stage 3 first.")

    except Exception as err:
        print(f"Error: {str(err)}")
        import traceback
        traceback.print_exc()

# Run the Step 4 workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()