import os
import sys
import logging
import subprocess
import shutil
from pathlib import Path
import importlib.util
import numpy as np

class StemSagePipeline:
    def __init__(self, args):
        self.args = args
        self.output_dir = Path(args.out)
        self.setup_logging()
        
        # Set the file path (all in the output directory)
        self.positive_bed = args.positive_bed
        self.negative_bed = args.negative_bed
        self.positive_fasta = args.positive_fasta
        self.negative_fasta = args.negative_fasta
        self.genome_fa = args.genome_fa
        self.extend_bp = args.extend

        self.selected_model = getattr(args, 'model', 'xgboost')
        self.test_size = getattr(args, 'test_size', 0.2)
        self.random_state = getattr(args, 'random_state', 42)
        self.cv_folds = getattr(args, 'cv_folds', 5)
        self.threads = getattr(args, 'threads', 1)
        
        self.max_motifs = getattr(args, 'max_motifs', 5)
        self.similar_matching = getattr(args, 'similar_matching', True)
        self.similarity_threshold = getattr(args, 'similarity_threshold', 0.9)
        self.max_length_diff = getattr(args, 'max_length_diff', 1)

        # Intermediate files (in the output directory)
        self.processed_positive_fasta = self.output_dir / "transformed_positive_sequences.fasta"
        self.generated_negative_fasta = self.output_dir / "matched_negative_sequences.fasta"
        self.positive_structures = self.output_dir / "positive_structures.txt"
        self.negative_structures = self.output_dir / "negative_structures.txt"
        
    def setup_logging(self):
        log_file = self.output_dir / 'Stemage_pipeline.log'
        
        #Create a custom log processor to capture child process output
        class SubprocessLogHandler(logging.Handler):
            def __init__(self, log_file_path):
                super().__init__()
                self.log_file_path = log_file_path
                
            def emit(self, record):
                log_entry = self.format(record)
                with open(self.log_file_path, 'a') as f:
                    f.write(log_entry + '\n')
        
        # Configure root logger
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        
        # Clear existing processors
        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
        
        #Create formatter
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        
        # File handler -used to record all logs
        file_handler = logging.FileHandler(log_file, mode='a')
        file_handler.setFormatter(formatter)
        root_logger.addHandler(file_handler)
        
        # Console handler -for output to the console
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        root_logger.addHandler(console_handler)
        
        self.logger = logging.getLogger(__name__)
        self.log_file = log_file
    
    def run_command(self, cmd, description):
        """Run the shell command and redirect the output to the log file"""
        self.logger.info(f"Running: {description}")
        self.logger.info(f"Command: {cmd}")
        
        try:
            # Run the command in the output directory and redirect stdout and stderr to the log file
            with open(self.log_file, 'a') as log_f:
                log_f.write(f"\n=== Command Output: {description} ===\n")
                log_f.write(f"Command: {cmd}\n")
                log_f.flush()
                
                result = subprocess.run(
                    cmd, 
                    shell=True, 
                    check=True,
                    cwd=str(self.output_dir),
                    stdout=log_f,
                    stderr=subprocess.STDOUT,
                    text=True
                )
                
            self.logger.info(f"Success: {description}")
            return True
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Failed: {description}")
            self.logger.error(f"Error code: {e.returncode}")
            
            # Write error information to the log file
            with open(self.log_file, 'a') as log_f:
                log_f.write(f"Command failed with return code: {e.returncode}\n")
                if e.output:
                    log_f.write(f"Error output: {e.output}\n")
            
            return False
    
    
    def check_file_exists(self, file_path, description):
        """Check if the file exists"""
        if not os.path.exists(file_path):
            self.logger.error(f"{description} not found: {file_path}")
            return False
        self.logger.info(f"{description} found: {file_path}")
        return True
    
    def step1_bed_preprocess(self):
        """Section 1: turn the positive BED file into a FASTA file.

        Fills in the ``bedtofasta_preprocess_template.sh`` template with
        absolute paths and the extension width, writes the result to the
        output directory as ``bedtofasta_preprocess.sh`` and runs it there.

        Returns:
            bool: True when the script succeeds and
            ``self.processed_positive_fasta`` exists afterwards, else False.
        """
        self.logger.info("=== Section 1: BED Preprocessing ===")

        # The script is written into, and executed from, the output directory.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        templates_dir = Path(__file__).parent.parent / 'stemsage' / 'templates'
        with open(templates_dir / 'bedtofasta_preprocess_template.sh', 'r') as handle:
            script_text = handle.read()

        # Absolute paths keep the script valid although it runs with the
        # output directory as its working directory.
        substitutions = (
            ('__INPUT_BED__', os.path.abspath(self.positive_bed)),
            ('__OUTPUT_FASTA__', os.path.abspath(self.processed_positive_fasta)),
            ('__REFERENCE_FA__', os.path.abspath(self.genome_fa)),
            ('__EXTEND_BP__', str(self.extend_bp)),
        )
        for placeholder, value in substitutions:
            script_text = script_text.replace(placeholder, value)

        script_path = self.output_dir / 'bedtofasta_preprocess.sh'
        with open(script_path, 'w') as handle:
            handle.write(script_text)
        os.chmod(script_path, 0o755)

        ran_ok = self.run_command(f'./{script_path.name}', 'BED preprocessing')

        # A zero exit status is not enough: the FASTA must actually be there.
        if not (ran_ok and os.path.exists(self.processed_positive_fasta)):
            self.logger.error(f"Failed to generate {self.processed_positive_fasta}")
            return False

        self.logger.info(f"Successfully generated {self.processed_positive_fasta}")
        size_in_bytes = os.path.getsize(self.processed_positive_fasta)
        self.logger.info(f"Output file size: {size_in_bytes} bytes")
        return True
    
    def step2_generate_negative_set(self):
        """Section 2: generate the negative sequence set.

        Fills in ``generate_negative_set_template.py``, writes it to the
        output directory as ``generate_negative_set.py`` and runs it there.
        The positive FASTA, genome FASTA and output FASTA paths are passed
        both as template substitutions and as command-line arguments.
        NOTE(review): which of the two the template actually consumes is not
        visible from this file.

        The positive FASTA is the user-supplied ``args.positive_fasta`` when
        one was given, otherwise the FASTA produced by BED preprocessing.

        Returns:
            bool: True when the script succeeds and
            ``self.generated_negative_fasta`` exists afterwards, else False.
        """
        import shlex  # local import: only needed to build this command line

        self.logger.info("=== Section 2: Generate Negative Set ===")

        # Make sure the output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Prefer a user-supplied positive FASTA over the BED-derived one.
        provided_fasta = getattr(self.args, 'positive_fasta', None)
        if provided_fasta:
            input_positive_fasta = provided_fasta
            self.logger.info(f"Using provided positive FASTA: {input_positive_fasta}")
        else:
            input_positive_fasta = self.processed_positive_fasta
            self.logger.info(f"Using BED-processed positive FASTA: {input_positive_fasta}")

        template_path = Path(__file__).parent.parent / 'stemsage' / 'templates' / 'generate_negative_set_template.py'
        script_path = self.output_dir / 'generate_negative_set.py'

        with open(template_path, 'r') as f:
            template = f.read()

        positive_abs = os.path.abspath(input_positive_fasta)
        genome_abs = os.path.abspath(self.genome_fa)
        negative_abs = os.path.abspath(self.generated_negative_fasta)

        # Replace placeholders - use absolute paths
        script_content = template.replace('__POSITIVE_FASTA__', positive_abs) \
                            .replace('__GENOME_FA__', genome_abs) \
                            .replace('__OUTPUT_FASTA__', negative_abs)

        with open(script_path, 'w') as f:
            f.write(script_content)

        # Run the helper with the interpreter executing this pipeline, so it
        # sees the same environment and installed packages; a bare "python"
        # from PATH may be missing or a different installation. shlex.quote
        # keeps paths containing spaces, quotes or "$" intact under the shell.
        interpreter = sys.executable or 'python'
        cmd = ' '.join(
            shlex.quote(part)
            for part in (interpreter, script_path.name, positive_abs, genome_abs, negative_abs)
        )
        success = self.run_command(cmd, 'Generate negative sequences')

        if success and self.generated_negative_fasta.exists():
            return True

        self.logger.error("Failed to generate negative FASTA file")
        return False
    
    def step3_rnafold(self, positive_input, negative_input, positive_output, negative_output):
        """Section 3: RNA structure prediction"""
        self.logger.info("=== Section 3: RNA Structure Prediction ===")
        
        #Convert to Path object uniformly
        positive_input = Path(positive_input)
        negative_input = Path(negative_input)
        positive_output = Path(positive_output)
        negative_output = Path(negative_output)
        
        # Make sure the output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Check if the input file exists
        if not self.check_file_exists(positive_input, 'Positive input FASTA'):
            return False
        if not self.check_file_exists(negative_input, 'Negative input FASTA'):
            return False
        
        # Read template
        template_path = Path(__file__).parent.parent / 'stemsage' / 'templates' / 'rnafold_template.sh'
        with open(template_path, 'r') as f:
            template = f.read()
        
        # Replace placeholders -use relative paths to the output directory
        script_content = template.replace('__POSITIVE_INPUT__', os.path.basename(positive_input)) \
                            .replace('__NEGATIVE_INPUT__', os.path.basename(negative_input)) \
                            .replace('__POSITIVE_OUTPUT__', os.path.basename(positive_output)) \
                            .replace('__NEGATIVE_OUTPUT__', os.path.basename(negative_output))
        
        #Write the final script to the output directory
        script_path = self.output_dir / 'rnafold_modified.sh'
        with open(script_path, 'w') as f:
            f.write(script_content)
        
        os.chmod(script_path, 0o755)
        
        # Copy input file to output directory
        positive_basename = os.path.basename(positive_input)
        negative_basename = os.path.basename(negative_input)
        
        # Fix: Use str() to convert Path objects to strings for comparison
        if not str(positive_input).startswith(str(self.output_dir)):
            shutil.copy2(positive_input, self.output_dir / positive_basename)
            # Update the path to the path in the output directory
            positive_input = self.output_dir / positive_basename
        
        if not str(negative_input).startswith(str(self.output_dir)):
            shutil.copy2(negative_input, self.output_dir / negative_basename)
            # Update the path to the path in the output directory
            negative_input = self.output_dir / negative_basename
        
        # Run RNAfold
        success = self.run_command(f'./{script_path.name}', 'RNA structure prediction')
        
        return success
    
    def step4_analysis_pipeline(self):
        """Step 4: Execute Step1-5 analysis process"""
        self.logger.info("=== Section 4: Running Analysis Pipeline (Step1-5) ===")
        
        # Check necessary input files
        if not self.check_file_exists(self.positive_structures, 'Positive structures file'):
            return False
        if not self.check_file_exists(self.negative_structures, 'Negative structures file'):
            return False
        
        # Dynamic import
        try:
            from steps.step1 import DataProcessor
            from steps.step2 import RNAStemClassifier
            from steps.step3 import StemPatternMiner
            from steps.step4 import MotifSequenceMapper
            from steps.step5 import MotifVisualizer
            
        except ImportError as e:
            self.logger.error(f"Import error: {e}")
            return False
        
        # Run Step1: Feature Extraction
        self.logger.info("Running Step1: Feature Extraction")
        try:
            step1 = DataProcessor(self.output_dir, threads=self.threads)
            df = step1.prepare_dataset(str(self.positive_structures), str(self.negative_structures))
            
            if df is None or len(df) == 0:
                self.logger.error("Step1 failed to generate feature dataset")
                return False
                
            self.logger.info(f"Step1 feature extraction completed using {self.threads} thread(s)")
            
        except Exception as e:
            self.logger.error(f"Step1 failed: {e}")
            import traceback
            traceback.print_exc()
            return False
        
        # Run Step2: Machine Learning
        self.logger.info("Running Step2: Machine Learning")
        try:
            step2 = RNAStemClassifier(self.output_dir, self.selected_model)
            step2.test_size = self.test_size
            step2.random_state = self.random_state
            step2.cv_folds = self.cv_folds
            step2_result = step2.run()
            if step2_result is False: 
                self.logger.warning("Step2 completed with warnings, continuing...")
        except Exception as e:
            self.logger.error(f"Step2 failed: {e}")
            return False
        
        self.logger.info(f"Using model: {self.selected_model}")
        self.logger.info(f"Test size: {self.test_size}, Random state: {self.random_state}")
        self.logger.info(f"CV folds: {self.cv_folds}, Threads: {self.threads}")
        
        # Run Step3: Pattern Mining
        self.logger.info("Running Step3: Pattern Mining")
        try:
            # Get parameters from pipeline parameters
            max_stem_length = getattr(self.args, 'max_stem_length', 5)
            min_stem_length = getattr(self.args, 'min_stem_length', 1)
            
            step3 = StemPatternMiner(
                output_dir=self.output_dir,
                max_stem_length=max_stem_length,
                min_stem_length=min_stem_length,
                threads=self.threads
            )
            
            pattern_analysis, positive_stems, negative_stems = step3.run(
                str(self.positive_structures), str(self.negative_structures)
            )
                           
        except Exception as e:
            self.logger.error(f"Step3 failed: {e}")
            return False
        
        # Run Step 4: Motif Mapping -now supports pattern matching after clustering
        self.logger.info("Running Step4: Motif Mapping")
        try:
            step4 = MotifSequenceMapper(
                self.output_dir, 
                max_motifs=self.max_motifs,
                similar_matching=self.similar_matching,
                similarity_threshold=self.similarity_threshold,
                max_length_diff=self.max_length_diff
            )
            step4_result = step4.run(pattern_analysis)
            if step4_result is False:
                self.logger.warning("Step4 completed with warnings, continuing...")
        except Exception as e:
            self.logger.error(f"Step4 failed: {e}")
            return False
        
        # Run Step5: Visualization
        self.logger.info("Running Step5: Visualization")
        try:
            step5 = MotifVisualizer(self.output_dir)
            step5.max_motifs = self.max_motifs
            step5_result = step5.run(cmd_args=self.args)
            if step5_result is False:
                self.logger.warning("Step5 completed with warnings, continuing...")
        except Exception as e:
            self.logger.error(f"Step5 failed: {e}")
            return False
        
        return True
    
    def determine_workflow(self):
        """Determine the workflow based on input parameters"""
        has_positive_bed = self.positive_bed and self.check_file_exists(self.positive_bed, 'Positive BED file')
        has_negative_bed = self.negative_bed and self.check_file_exists(self.negative_bed, 'Negative BED file')
        has_positive_fasta = self.positive_fasta and self.check_file_exists(self.positive_fasta, 'Positive FASTA file')
        has_negative_fasta = self.negative_fasta and self.check_file_exists(self.negative_fasta, 'Negative FASTA file')
        
        self.logger.info(f"Input status - Positive BED: {has_positive_bed}, Negative BED: {has_negative_bed}")
        self.logger.info(f"Input status - Positive FASTA: {has_positive_fasta}, Negative FASTA: {has_negative_fasta}")
        
        # Case 1: positive.bed → run 1-2-3-4
        if has_positive_bed and not has_negative_bed and not has_positive_fasta:
            self.logger.info("Workflow 1: positive.bed → Run 1-2-3-4")
            return self.workflow1()
        
        # Case 2: positive.bed, negative.bed → run 1, then run 3
        elif has_positive_bed and has_negative_bed and not has_positive_fasta:
            self.logger.info("Workflow 2: positive.bed, negative.bed → Run 1, then 3")
            return self.workflow2()
        
        # Case 3: positive.fasta → run 2 to generate negative.fasta, and then run 3
        elif has_positive_fasta and not has_negative_fasta:
            self.logger.info("Workflow 3: positive.fasta → Run 2, then 3")
            return self.workflow3()
        
        # Case 4: positive.fasta, negative.fasta → run 3 directly
        elif has_positive_fasta and has_negative_fasta:
            self.logger.info("Workflow 4: positive.fasta, negative.fasta → Run 3 directly")
            return self.workflow4()
        
        else:
            self.logger.error("Invalid input combination. Please check your input files.")
            return False
    
    def workflow1(self):
        """Workflow 1: positive BED only - run sections 1, 2, 3 and 4.

        The sections run in order and the chain stops at the first one that
        reports failure.
        """
        preparation_ok = (
            # Section 1: BED preprocessing
            self.step1_bed_preprocess()
            # Section 2: generate the negative set
            and self.step2_generate_negative_set()
            # Section 3: RNA structure prediction
            and self.step3_rnafold(
                self.processed_positive_fasta,
                self.generated_negative_fasta,
                self.positive_structures,
                self.negative_structures,
            )
        )
        if not preparation_ok:
            return False

        # Section 4: analysis process
        return self.step4_analysis_pipeline()
    
    def workflow2(self):
        """Workflow 2: positive and negative BED - preprocess both, then fold.

        The positive BED goes through section 1. The negative BED is turned
        into ``matched_negative_sequences.fasta`` with the same shell
        template, written as ``bedtofasta_preprocess_negative.sh``. Structure
        prediction and the analysis steps follow.
        """
        # Positive BED
        if not self.step1_bed_preprocess():
            return False

        # Negative BED
        self.logger.info("=== Negative BED Preprocessing ===")

        templates_dir = Path(__file__).parent.parent / 'stemsage' / 'templates'
        with open(templates_dir / 'bedtofasta_preprocess_template.sh', 'r') as handle:
            negative_script_text = handle.read()

        negative_processed_fasta = self.output_dir / "matched_negative_sequences.fasta"

        substitutions = (
            ('__INPUT_BED__', os.path.abspath(self.negative_bed)),
            ('__OUTPUT_FASTA__', os.path.abspath(negative_processed_fasta)),
            ('__REFERENCE_FA__', os.path.abspath(self.genome_fa)),
            ('__EXTEND_BP__', str(self.extend_bp)),
        )
        for placeholder, value in substitutions:
            negative_script_text = negative_script_text.replace(placeholder, value)

        # The negative-set script lives next to the positive one.
        script_path = self.output_dir / 'bedtofasta_preprocess_negative.sh'
        with open(script_path, 'w') as handle:
            handle.write(negative_script_text)
        os.chmod(script_path, 0o755)

        negatives_ready = self.run_command(f'./{script_path.name}', 'Negative BED preprocessing')
        if not negatives_ready:
            return False

        # RNA structure prediction
        structures_ready = self.step3_rnafold(
            str(self.processed_positive_fasta),
            str(negative_processed_fasta),
            str(self.positive_structures),
            str(self.negative_structures),
        )
        if not structures_ready:
            return False

        # Analysis process
        return self.step4_analysis_pipeline()
    
    def workflow3(self):
        """Workflow 3: positive FASTA only - generate negatives, fold, analyse."""
        # Section 2 produces the negative FASTA that was not supplied.
        negatives_ready = self.step2_generate_negative_set()
        if not negatives_ready:
            return False

        # Fold the user's positive FASTA together with the generated negatives.
        structures_ready = self.step3_rnafold(
            str(self.positive_fasta),
            str(self.generated_negative_fasta),
            str(self.positive_structures),
            str(self.negative_structures),
        )
        if not structures_ready:
            return False

        # Analysis process
        return self.step4_analysis_pipeline()

    def workflow4(self):
        """Workflow 4: positive.fasta, negative.fasta → run directly 3"""
        # RNA structure prediction -need to ensure that the input file is in the output directory
        if not self.step3_rnafold(str(self.positive_fasta), str(self.negative_fasta),
                                str(self.positive_structures), str(self.negative_structures)):
            return False
        
        # Analysis process
        return self.step4_analysis_pipeline()
    
    def run(self):
        """Run the entire process"""
        self.logger.info("Starting RNA Stem-loop Analysis Pipeline")
        self.logger.info(f"Genome reference: {self.genome_fa}")
        self.logger.info(f"Extension: {self.extend_bp} bp")
        
        success = self.determine_workflow()
        
        # Clean all .sh files
        self.logger.info("Cleaning up tmp files...")
        for sh_file in self.output_dir.glob("*.sh"):
            try:
                sh_file.unlink()
                self.logger.info(f"Removed: {sh_file}")
            except Exception as e:
                self.logger.warning(f"Failed to remove {sh_file}: {e}")
        
        if success:
            self.logger.info("Pipeline completed successfully!")
            print("\n" + "="*60)
            print("ANALYSIS PIPELINE COMPLETED SUCCESSFULLY!")
            print("="*60)
            print("Output files generated:")
            print("- positive_structures.txt, negative_structures.txt")
            print("- rna_features_dataset.csv")
            print("- stem_patterns_analysis.csv")
            print("- motif_visualization_report.html")
            print("- detailed motif information in ./Motif folder")
        else:
            self.logger.error("Pipeline failed!")
            print("\nPipeline failed! Check Stemage_pipeline.log for details.")
        
        return success