import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import (roc_auc_score, precision_recall_curve, classification_report, 
                           confusion_matrix, average_precision_score, roc_curve)
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import json
from collections import Counter
from scipy import stats
import warnings
import argparse
from pathlib import Path

# Silence every warning process-wide so the training log stays readable.
warnings.filterwarnings('ignore')

# SHAP is optional: when it cannot be imported, SHAP_AVAILABLE is False and
# the classifier falls back to the model's built-in feature importances.
try:
    import shap
except ImportError:
    print("SHAP not available, using alternative feature importance methods")
    SHAP_AVAILABLE = False
else:
    SHAP_AVAILABLE = True

class RNAStemClassifier:
    """Train, evaluate and export a binary classifier for RNA stem features.

    ``run`` drives the whole stage. It reads ``rna_features_dataset.csv``
    from ``output_dir`` and balances and splits the data. It trains the model
    named by ``selected_model`` ('xgboost' or 'random_forest') and
    cross-validates it. It compares positive vs. negative feature
    distributions and runs SHAP when available. Finally it writes plots, the
    model, the scaler and summaries back to ``output_dir``.
    """

    def __init__(self, output_dir=".", selected_model='xgboost'):
        self.output_dir = Path(output_dir)
        self.models = {}              # model name -> fitted estimator
        self.feature_importance = {}  # model name -> feature_importances_ array
        self.scaler = None            # StandardScaler, fitted in prepare_training_data
        self.best_params = {}         # NOTE(review): not populated anywhere in this file
        self.selected_model = selected_model  # model chosen by the user
        self.final_model = None       # the model that gets exported
        self.test_size = 0.2
        self.random_state = 42
        self.cv_folds = 5

    def run(self):
        """Run the complete training pipeline (the method called by pipeline.py).

        Returns:
            bool: True when the pipeline finished, False when the feature
            file is missing or has no ``label`` column.
        """
        print("Starting RNA Stem Classifier training...")

        # Load the feature table produced by the previous stage.
        feature_file = self.output_dir / 'rna_features_dataset.csv'
        try:
            df = pd.read_csv(feature_file)
            print(f"Loaded dataset with shape: {df.shape}")
        except FileNotFoundError:
            print(f"Error: Feature file not found at {feature_file}")
            return False

        if 'label' not in df.columns:
            print("Error: 'label' column not found in dataset!")
            return False

        X_train, X_test, y_train, y_test, feature_names = self.prepare_training_data(
            df, test_size=self.test_size, random_state=self.random_state
        )

        # Train the selected model (stored on self.final_model / self.models).
        self.train_selected_model(X_train, X_test, y_train, y_test, feature_names)

        # Cross-validate the same configuration on the training split.
        cv_scores = self.cross_validation(X_train, y_train, cv_folds=self.cv_folds)

        # Compare positive and negative samples feature by feature (raw values).
        feature_diff_df = self.analyze_positive_negative_differences(df, feature_names)

        # SHAP analysis; it saves its own plot, the returned table is not needed here.
        self.shap_analysis(X_test, feature_names)

        self.create_comprehensive_plots(X_test, y_test, feature_names, feature_diff_df)

        self.save_models_and_results(df, feature_names, feature_diff_df)

        print("\n" + "="*50)
        print("STAGE 2 COMPLETED SUCCESSFULLY!")
        print("="*50)

        # Report the key findings.
        print(f"\n=== {self.selected_model.upper()} MODEL KEY FINDINGS ===")
        if len(feature_diff_df) > 0:
            best_feature = feature_diff_df.iloc[0]
            print(f"Most discriminative feature: {best_feature['feature']}")
            print(f"  - Cohen's d: {best_feature['cohens_d']:.3f}")
            print(f"  - p-value: {best_feature['p_value']:.4f}")

        print(f"Cross-validation AUC: {np.mean(cv_scores):.4f}")
        print(f"Model saved as: {self.selected_model}_stem_classifier.pkl")

        return True

    def prepare_training_data(self, df, test_size=0.2, random_state=42):
        """Build scaled, stratified train/test splits from the feature table.

        When negatives outnumber positives by more than 1.5x, negatives are
        randomly undersampled to the number of positives. The split happens
        before scaling, and the StandardScaler (stored on ``self.scaler``) is
        fitted on the training split only. This keeps test-set statistics out
        of both training and the exported scaler.

        Args:
            df: DataFrame with a ``label`` column (1 = positive,
                0 = negative); every other column is treated as a feature.
            test_size: Fraction of samples held out for testing.
            random_state: Seed for undersampling and for the split.

        Returns:
            tuple: ``(X_train, X_test, y_train, y_test, feature_cols)``. The
            X arrays are scaled numpy arrays, and ``feature_cols`` lists the
            feature column names in column order.
        """
        print("Preparing training data with balanced positive/negative samples...")

        positive_df = df[df['label'] == 1]
        negative_df = df[df['label'] == 0]

        print(f"Positive samples: {len(positive_df)}")
        print(f"Negative samples: {len(negative_df)}")

        # Undersample negatives when they heavily outnumber positives.
        if len(negative_df) > 1.5 * len(positive_df):
            print("Performing undersampling of negative samples for balance...")
            negative_df_sampled = negative_df.sample(n=len(positive_df), random_state=random_state)
            balanced_df = pd.concat([positive_df, negative_df_sampled])
        else:
            balanced_df = df

        print(f"Balanced dataset: {len(balanced_df)} samples")

        feature_cols = [col for col in balanced_df.columns if col != 'label']
        # Make sure every feature is numeric before scaling.
        X = self._preprocess_features(balanced_df[feature_cols])
        y = balanced_df['label'].values

        # Stratified split that keeps the positive/negative ratio. It is done
        # BEFORE scaling so the scaler never sees test rows.
        X_train_raw, X_test_raw, y_train, y_test = train_test_split(
            X, y,
            test_size=test_size, random_state=random_state,
            stratify=y, shuffle=True
        )

        self.scaler = StandardScaler()
        X_train = self.scaler.fit_transform(X_train_raw)
        X_test = self.scaler.transform(X_test_raw)

        print(f"Training set: {X_train.shape[0]} samples ({np.sum(y_train==1)} positive, {np.sum(y_train==0)} negative)")
        print(f"Test set: {X_test.shape[0]} samples ({np.sum(y_test==1)} positive, {np.sum(y_test==0)} negative)")

        return X_train, X_test, y_train, y_test, feature_cols

    def _preprocess_features(self, X_data):
        """Return a copy of ``X_data`` in which every object column is numeric.

        For each object-dtype column:
          * Values are coerced with ``pd.to_numeric(errors='coerce')``.
          * If more than 30% become NaN, the column is treated as categorical.
            The ORIGINAL strings are label-encoded (missing values become -1).
          * Otherwise, any NaN introduced is filled with the column median.
        Any NaN left anywhere afterwards is filled with 0.
        """
        print("Preprocessing features to ensure numeric types...")

        X_processed = X_data.copy()
        n_rows = len(X_processed)

        for col in X_processed.columns:
            if X_processed[col].dtype != 'object':
                continue

            original = X_processed[col]
            try:
                converted = pd.to_numeric(original, errors='coerce')
            except Exception as e:
                print(f"Error processing column {col}: {e}, using label encoding")
                X_processed[col] = original.astype('category').cat.codes
                continue

            n_missing = int(converted.isna().sum())
            nan_ratio = n_missing / n_rows if n_rows else 0.0
            if nan_ratio > 0.3:
                print(f"Column {col} has {nan_ratio:.1%} NaN after conversion, using label encoding")
                # Encode the original strings; encoding the coerced column would
                # map every non-numeric value to the same NaN code.
                X_processed[col] = original.astype('category').cat.codes
            elif n_missing > 0:
                median_val = converted.median()
                X_processed[col] = converted.fillna(median_val)
                print(f"Filled {n_missing} NaN values in column {col} with median {median_val}")
            else:
                X_processed[col] = converted

        # Final safety net for any NaN that is still present.
        if X_processed.isna().any().any():
            print("Filling remaining NaN values with 0")
            X_processed = X_processed.fillna(0)

        return X_processed

    def train_selected_model(self, X_train, X_test, y_train, y_test, feature_names):
        """Train the user-selected model, store it as ``self.final_model`` and return it.

        Raises:
            ValueError: If ``self.selected_model`` is not 'xgboost' or 'random_forest'.
        """
        print(f"\n" + "="*50)
        print(f"TRAINING {self.selected_model.upper()} MODEL")
        print("="*50)

        trainers = {
            'xgboost': self._train_xgboost,
            'random_forest': self._train_random_forest,
        }
        if self.selected_model not in trainers:
            raise ValueError(f"Unsupported model: {self.selected_model}")

        model = trainers[self.selected_model](X_train, X_test, y_train, y_test, feature_names)
        self.final_model = model
        return model

    @staticmethod
    def _scale_pos_weight(y):
        """Return #negatives / #positives in ``y`` (1.0 when there are no positives)."""
        labels = np.asarray(y)
        n_pos = int(np.sum(labels == 1))
        n_neg = int(np.sum(labels == 0))
        return n_neg / n_pos if n_pos > 0 else 1.0

    def _xgboost_params(self, y_train):
        """Hyper-parameters for XGBoost, shared by final training and cross-validation."""
        return {
            'n_estimators': 300,
            'max_depth': 6,
            'learning_rate': 0.1,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'reg_alpha': 0.1,
            'reg_lambda': 0.1,
            'random_state': self.random_state,
            'eval_metric': 'logloss',
            # Re-weight positives by the class ratio of the training labels.
            'scale_pos_weight': self._scale_pos_weight(y_train),
        }

    def _random_forest_params(self):
        """Hyper-parameters for the random forest, shared by final training and cross-validation."""
        return {
            'n_estimators': 200,
            'max_depth': 15,
            'min_samples_split': 2,
            'min_samples_leaf': 1,
            'random_state': self.random_state,
            'class_weight': 'balanced',
        }

    def _build_estimator(self, y_train):
        """Return an unfitted estimator configured exactly like the exported model.

        Raises:
            ValueError: If ``self.selected_model`` is not supported.
        """
        if self.selected_model == 'xgboost':
            return xgb.XGBClassifier(**self._xgboost_params(y_train))
        if self.selected_model == 'random_forest':
            return RandomForestClassifier(**self._random_forest_params())
        raise ValueError(f"Unsupported model: {self.selected_model}")

    def _train_xgboost(self, X_train, X_test, y_train, y_test, feature_names):
        """Fit XGBoost, evaluate it on the test split, register and return it."""
        xgb_params = self._xgboost_params(y_train)
        print(f"Using scale_pos_weight: {xgb_params['scale_pos_weight']:.2f}")

        xgb_model = xgb.XGBClassifier(**xgb_params)
        # eval_set only feeds the training log (verbose=50); no early stopping
        # is configured, so the test split does not influence the fit.
        xgb_model.fit(X_train, y_train,
                      eval_set=[(X_train, y_train), (X_test, y_test)],
                      verbose=50)

        y_pred_proba = xgb_model.predict_proba(X_test)[:, 1]
        y_pred = xgb_model.predict(X_test)

        self._evaluate_model("XGBoost", y_test, y_pred, y_pred_proba, xgb_model, feature_names)

        self.models['xgboost'] = xgb_model
        self.feature_importance['xgboost'] = xgb_model.feature_importances_

        return xgb_model

    def _train_random_forest(self, X_train, X_test, y_train, y_test, feature_names):
        """Fit a random forest, evaluate it on the test split, register and return it."""
        rf_model = RandomForestClassifier(**self._random_forest_params())
        rf_model.fit(X_train, y_train)

        y_pred_proba = rf_model.predict_proba(X_test)[:, 1]
        y_pred = rf_model.predict(X_test)

        self._evaluate_model("Random Forest", y_test, y_pred, y_pred_proba, rf_model, feature_names)

        self.models['random_forest'] = rf_model
        self.feature_importance['random_forest'] = rf_model.feature_importances_

        return rf_model

    def _evaluate_model(self, model_name, y_test, y_pred, y_pred_proba, model, feature_names):
        """Print AUC, average precision, a classification report, the confusion matrix,
        sensitivity/specificity and the top-10 feature importances.

        Returns:
            tuple: ``(auc_score, ap_score)``.
        """
        print(f"\n--- {model_name} Evaluation ---")

        auc_score = roc_auc_score(y_test, y_pred_proba)
        ap_score = average_precision_score(y_test, y_pred_proba)

        print(f"AUC Score: {auc_score:.4f}")
        print(f"Average Precision: {ap_score:.4f}")

        print("\nClassification Report:")
        print(classification_report(y_test, y_pred, labels=[0, 1],
                                    target_names=['Negative', 'Positive']))

        # Fix the label order so the matrix is always 2x2 (rows = true 0/1).
        cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
        print(f"Confusion Matrix:")
        print(f"True Negatives: {cm[0,0]}, False Positives: {cm[0,1]}")
        print(f"False Negatives: {cm[1,0]}, True Positives: {cm[1,1]}")

        specificity = cm[0,0] / (cm[0,0] + cm[0,1]) if (cm[0,0] + cm[0,1]) > 0 else 0
        sensitivity = cm[1,1] / (cm[1,0] + cm[1,1]) if (cm[1,0] + cm[1,1]) > 0 else 0
        print(f"Sensitivity (Recall): {sensitivity:.4f}")
        print(f"Specificity: {specificity:.4f}")

        if hasattr(model, 'feature_importances_'):
            feature_imp = pd.DataFrame({
                'feature': feature_names,
                'importance': model.feature_importances_
            }).sort_values('importance', ascending=False)

            print(f"\nTop 10 Important Features:")
            for _, row in feature_imp.head(10).iterrows():
                print(f"  {row['feature']}: {row['importance']:.4f}")

        return auc_score, ap_score

    def cross_validation(self, X, y, cv_folds=5):
        """Estimate ROC AUC with stratified, shuffled k-fold cross-validation.

        The estimator is built by ``_build_estimator``, so the score describes
        the same configuration that is trained and exported.

        Returns:
            numpy.ndarray: per-fold ROC AUC scores.
        """
        print("\n" + "="*50)
        print("CROSS-VALIDATION ANALYSIS")
        print("="*50)

        model = self._build_estimator(y)
        splitter = StratifiedKFold(n_splits=cv_folds, shuffle=True,
                                   random_state=self.random_state)
        cv_scores = cross_val_score(model, X, y, cv=splitter, scoring='roc_auc')

        print(f"Cross-validation AUC scores: {[f'{score:.4f}' for score in cv_scores]}")
        print(f"Mean AUC: {np.mean(cv_scores):.4f} (+/- {np.std(cv_scores)*2:.4f})")

        return cv_scores

    def analyze_positive_negative_differences(self, df, feature_names, top_n=15):
        """Compare each feature between positive (label 1) and negative (label 0) rows.

        For every feature the method computes the class means and Welch's
        t-test p-value. It also computes Cohen's d (pooled standard deviation)
        and the positive/negative fold change. Features that cannot be
        analysed (e.g. non-numeric) are reported and skipped.

        Returns:
            pandas.DataFrame: one row per feature, sorted by ``abs_cohens_d``
            descending (an empty DataFrame when nothing could be analysed).
        """
        print("\n" + "="*50)
        print("POSITIVE vs NEGATIVE FEATURE ANALYSIS")
        print("="*50)

        positive_df = df[df['label'] == 1]
        negative_df = df[df['label'] == 0]

        print(f"Analyzing {len(feature_names)} features...")
        print(f"Positive samples: {len(positive_df)}, Negative samples: {len(negative_df)}")

        feature_differences = []

        for feature in feature_names:
            try:
                pos_values = positive_df[feature].dropna()
                neg_values = negative_df[feature].dropna()

                if len(pos_values) == 0 or len(neg_values) == 0:
                    continue

                pos_mean = pos_values.mean()
                neg_mean = neg_values.mean()
                n1, n2 = len(pos_values), len(neg_values)

                # Welch's t-test (unequal variances).
                _, p_value = stats.ttest_ind(pos_values, neg_values, equal_var=False)

                # Cohen's d with a pooled standard deviation. A group with a
                # single observation contributes nothing to the pooled sum of
                # squares; pandas' sample variance would be NaN there.
                pos_var = pos_values.var() if n1 > 1 else 0.0
                neg_var = neg_values.var() if n2 > 1 else 0.0
                dof = n1 + n2 - 2
                pooled_std = (np.sqrt(((n1 - 1) * pos_var + (n2 - 1) * neg_var) / dof)
                              if dof > 0 else 0.0)
                cohens_d = (pos_mean - neg_mean) / pooled_std if pooled_std > 0 else 0

                # Fold change. A zero negative mean yields a signed infinity,
                # or NaN when both means are zero (0/0 is undefined).
                if neg_mean != 0:
                    fold_change = pos_mean / neg_mean
                elif pos_mean != 0:
                    fold_change = float(np.copysign(np.inf, pos_mean))
                else:
                    fold_change = float('nan')

                feature_differences.append({
                    'feature': feature,
                    'positive_mean': pos_mean,
                    'negative_mean': neg_mean,
                    'mean_difference': pos_mean - neg_mean,
                    'cohens_d': cohens_d,
                    'p_value': p_value,
                    'fold_change': fold_change,
                    'abs_cohens_d': abs(cohens_d)
                })
            except Exception as e:
                print(f"Error analyzing feature {feature}: {e}")
                continue

        diff_df = pd.DataFrame(feature_differences)

        if len(diff_df) == 0:
            print("No features could be analyzed!")
            return pd.DataFrame()

        # Rank by absolute effect size.
        diff_df_sorted = diff_df.sort_values('abs_cohens_d', ascending=False)

        print(f"\nTop {top_n} Features with Largest Differences:")
        print("Feature".ljust(25) + "PosMean".ljust(10) + "NegMean".ljust(10) +
              "Diff".ljust(10) + "Cohen's d".ljust(12) + "p-value".ljust(10))
        print("-" * 80)

        for _, row in diff_df_sorted.head(top_n).iterrows():
            print(f"{row['feature'][:24]:24} {row['positive_mean']:8.3f} {row['negative_mean']:8.3f} "
                  f"{row['mean_difference']:8.3f} {row['cohens_d']:10.3f} {row['p_value']:8.4f}")

        return diff_df_sorted

    @staticmethod
    def _positive_class_shap(shap_values):
        """Reduce SHAP output to a 2-D ``(n_samples, n_features)`` array for class 1.

        Depending on the SHAP version and explainer, binary-classifier output
        may be a list with one array per class or a single array with a
        trailing class axis. Both are sliced to the positive class; a 2-D
        array is returned unchanged.
        """
        if isinstance(shap_values, list):
            return np.asarray(shap_values[1])
        values = np.asarray(shap_values)
        if values.ndim == 3:
            return values[..., 1]
        return values

    def shap_analysis(self, X_test, feature_names):
        """Explain the selected model with SHAP and save a summary plot.

        The plot is saved to ``<output_dir>/shap_summary_<model>.png``. When
        SHAP is unavailable or fails, this falls back to the model's built-in
        feature importances.

        Returns:
            pandas.DataFrame | None: features ranked by mean |SHAP| (or by
            built-in importance in the fallback). Returns None when the
            selected model has not been trained.
        """
        print(f"\nPerforming SHAP analysis for {self.selected_model}...")

        output_file = self.output_dir / f'shap_summary_{self.selected_model}.png'

        if self.selected_model not in self.models:
            print(f"Model {self.selected_model} not found!")
            return None

        if not SHAP_AVAILABLE:
            print("SHAP not available. Using built-in feature importance instead.")
            return self._get_feature_importance_alternative(feature_names)

        model = self.models[self.selected_model]

        try:
            X_test_processed = self._preprocess_data_for_shap(X_test)

            if self.selected_model == 'xgboost':
                explainer = shap.TreeExplainer(model, feature_perturbation="interventional")
            else:
                # Random forest: model-agnostic KernelExplainer using (up to)
                # the first 100 test rows as background data.
                explainer = shap.KernelExplainer(model.predict_proba, X_test_processed[:100])

            shap_values = self._positive_class_shap(explainer.shap_values(X_test_processed))

            plt.figure(figsize=(10, 8))
            shap.summary_plot(shap_values, X_test_processed, feature_names=feature_names, show=False)
            plt.title(f'SHAP Feature Importance - {self.selected_model.upper()}')
            plt.tight_layout()
            plt.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close()

            shap_df = pd.DataFrame({
                'feature': feature_names,
                'mean_abs_shap': np.mean(np.abs(shap_values), axis=0)
            }).sort_values('mean_abs_shap', ascending=False)

            print("\nTop Features by SHAP Importance:")
            for _, row in shap_df.head(10).iterrows():
                print(f"  {row['feature']}: {row['mean_abs_shap']:.4f}")

            return shap_df

        except Exception as e:
            print(f"SHAP analysis failed: {e}")
            print("Using alternative feature importance method...")
            return self._get_feature_importance_alternative(feature_names)

    def _preprocess_data_for_shap(self, X_data):
        """Return ``X_data`` with object columns made numeric and NaN filled with 0.

        An object column is label-encoded (from its ORIGINAL values) when more
        than half of it fails numeric conversion. The result is a numpy array
        if the input was one, otherwise a DataFrame.
        """
        print("Preprocessing data for SHAP analysis...")

        if isinstance(X_data, np.ndarray):
            X_df = pd.DataFrame(X_data)
        else:
            X_df = X_data.copy()

        for col in X_df.columns:
            if X_df[col].dtype != 'object':
                continue

            original = X_df[col]
            try:
                converted = pd.to_numeric(original, errors='coerce')
            except Exception as e:
                print(f"Error converting column {col}: {e}")
                X_df[col] = original.astype('category').cat.codes
                continue

            if converted.isna().sum() > len(X_df) * 0.5:
                print(f"Warning: Column {col} has many non-numeric values, using alternative encoding")
                X_df[col] = original.astype('category').cat.codes
            else:
                X_df[col] = converted

        X_df = X_df.fillna(0)

        return X_df.values if isinstance(X_data, np.ndarray) else X_df

    def _get_feature_importance_alternative(self, feature_names):
        """Rank features by the selected model's built-in importances (SHAP fallback).

        Returns:
            pandas.DataFrame | None: sorted ``feature``/``importance`` table,
            or None when no importances were recorded for the model.
        """
        if self.selected_model in self.feature_importance:
            importance_df = pd.DataFrame({
                'feature': feature_names,
                'importance': self.feature_importance[self.selected_model]
            }).sort_values('importance', ascending=False)

            print(f"\nTop Features by {self.selected_model.upper()} Importance:")
            for _, row in importance_df.head(10).iterrows():
                print(f"  {row['feature']}: {row['importance']:.4f}")

            return importance_df
        return None

    def create_comprehensive_plots(self, X_test, y_test, feature_names, feature_diff_df):
        """Save a 2x2 summary figure to ``<output_dir>/model_analysis_<model>.png``.

        The panels are: ROC curve, top-10 Cohen's d, top-10 model
        importances, and the precision-recall curve.
        """
        print("\nCreating comprehensive visualizations...")

        output_file = self.output_dir / f'model_analysis_{self.selected_model}.png'

        # Test-set probabilities are shared by the ROC and PR panels.
        y_pred_proba = None
        if self.final_model is not None:
            y_pred_proba = self.final_model.predict_proba(X_test)[:, 1]

        plt.figure(figsize=(16, 12))

        # 1. ROC curve
        plt.subplot(2, 2, 1)
        if y_pred_proba is not None:
            fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
            auc_score = roc_auc_score(y_test, y_pred_proba)
            plt.plot(fpr, tpr, linewidth=2, label=f'{self.selected_model} (AUC = {auc_score:.3f})')

        plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        plt.xlabel('False Positive Rate', fontsize=12)
        plt.ylabel('True Positive Rate', fontsize=12)
        plt.title(f'ROC Curve - {self.selected_model.upper()}', fontsize=14)
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)

        # 2. Effect sizes (red = higher in positives, blue = higher in negatives)
        plt.subplot(2, 2, 2)
        if feature_diff_df is not None and len(feature_diff_df) > 0:
            top_features = feature_diff_df.head(10)
            y_pos = np.arange(len(top_features))
            colors = ['red' if x > 0 else 'blue' for x in top_features['cohens_d']]

            plt.barh(y_pos, top_features['cohens_d'], color=colors, alpha=0.7)
            plt.yticks(y_pos, [f[:20] for f in top_features['feature']])
            plt.xlabel("Cohen's d Effect Size", fontsize=12)
            plt.title('Top Feature Differences\n(Positive vs Negative)', fontsize=14)
            plt.grid(True, alpha=0.3)

        # 3. Model feature importance
        plt.subplot(2, 2, 3)
        if self.selected_model in self.feature_importance:
            imp_df = pd.DataFrame({
                'feature': feature_names,
                'importance': self.feature_importance[self.selected_model]
            }).sort_values('importance', ascending=False).head(10)

            plt.barh(range(len(imp_df)), imp_df['importance'], alpha=0.7)
            plt.yticks(range(len(imp_df)), [f[:20] for f in imp_df['feature']])
            plt.xlabel(f'{self.selected_model.upper()} Feature Importance', fontsize=12)
            plt.title(f'Top {self.selected_model.upper()} Features', fontsize=14)
            plt.grid(True, alpha=0.3)

        # 4. Precision-recall curve
        plt.subplot(2, 2, 4)
        if y_pred_proba is not None:
            precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
            ap_score = average_precision_score(y_test, y_pred_proba)
            plt.plot(recall, precision, linewidth=2, label=f'AP = {ap_score:.3f}')
            plt.xlabel('Recall', fontsize=12)
            plt.ylabel('Precision', fontsize=12)
            plt.title('Precision-Recall Curve', fontsize=14)
            plt.legend(fontsize=10)
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Visualizations saved to 'model_analysis_{self.selected_model}.png'")

    def save_models_and_results(self, df, feature_names, feature_diff_df):
        """Write the model, scaler, feature-importance CSV and JSON summary to ``output_dir``.

        Files written:
          * ``<model>_stem_classifier.pkl``: the fitted model (joblib).
          * ``feature_scaler.pkl``: the fitted scaler, only when one exists.
          * ``feature_importance_<model>.csv``: importances plus Cohen's d.
          * ``training_summary_<model>.json``: sample and feature counts.
        Does nothing when no model has been trained.
        """
        print("\nSaving models and results...")

        if self.final_model is None:
            print("No model to save!")
            return

        model_filename = self.output_dir / f'{self.selected_model}_stem_classifier.pkl'
        joblib.dump(self.final_model, model_filename)

        # Only set when a scaler is actually written (avoids an unbound name below).
        scaler_filename = None
        if self.scaler is not None:
            scaler_filename = self.output_dir / 'feature_scaler.pkl'
            joblib.dump(self.scaler, scaler_filename)

        importances = self.feature_importance.get(self.selected_model)
        if importances is None:
            # Keep the column length consistent with feature_names.
            importances = [np.nan] * len(feature_names)
        importance_data = {
            'feature': feature_names,
            f'{self.selected_model}_importance': importances
        }

        # Attach effect sizes; features that were not analysed get 0.
        if feature_diff_df is not None and len(feature_diff_df) > 0:
            cohens_d_dict = feature_diff_df.set_index('feature')['cohens_d'].to_dict()
            importance_data['cohens_d'] = [cohens_d_dict.get(f, 0) for f in feature_names]

        importance_df = pd.DataFrame(importance_data)
        importance_file = self.output_dir / f'feature_importance_{self.selected_model}.csv'
        importance_df.to_csv(importance_file, index=False)

        results_summary = {
            'selected_model': self.selected_model,
            'positive_samples': len(df[df['label'] == 1]),
            'negative_samples': len(df[df['label'] == 0]),
            'total_features': len(feature_names),
            'shap_available': SHAP_AVAILABLE
        }

        summary_file = self.output_dir / f'training_summary_{self.selected_model}.json'
        with open(summary_file, 'w') as f:
            json.dump(results_summary, f, indent=2)

        print("Models and results saved successfully!")
        print(f"- {self.selected_model.upper()} model: {model_filename}")
        if scaler_filename is not None:
            print(f"- Feature scaler: {scaler_filename}")
        print(f"- Feature importance: {importance_file}")
        print(f"- Training summary: {summary_file}")

def parse_arguments():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Options:
        --model       Which classifier to train: 'xgboost' (default) or
                      'random_forest'.
        --output-dir  Directory used for inputs and results (default: '.').
    """
    cli = argparse.ArgumentParser(description='RNA Stem Classifier Training')
    cli.add_argument(
        '--model',
        type=str,
        choices=['xgboost', 'random_forest'],
        default='xgboost',
        help='Model to train (default: xgboost)',
    )
    cli.add_argument(
        '--output-dir',
        type=str,
        default='.',
        help='Output directory for results',
    )
    namespace = cli.parse_args()
    return namespace

# 主执行函数
def main():
    """Command-line entry point: train the model selected with ``--model``.

    Loading and validating ``<output-dir>/rna_features_dataset.csv`` is left
    to ``RNAStemClassifier.run()``. That method reads the file itself and
    prints the success banner, so the CSV is not read twice. This function
    only adds the failure banner.

    Returns:
        int: 0 on success, 1 on failure (usable as a process exit code).
    """
    args = parse_arguments()

    # Initialise the classifier with the user-selected model.
    classifier = RNAStemClassifier(output_dir=args.output_dir, selected_model=args.model)

    print(f"Using {args.model} model with run() method...")
    success = classifier.run()

    if not success:
        print("\n" + "="*50)
        print("STAGE 2 FAILED!")
        print("="*50)
        return 1
    return 0

if __name__ == "__main__":
    # Propagate main()'s status as the process exit code (None/0 -> success).
    raise SystemExit(main())