"""
ChalkML QUANTUM Engine
======================
Deterministic data compression via quantum feature collapse.

Core Concept:
- Multiple correlated columns → Single quantum feature
- Deterministic normalization → Schema-based encoding
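- Approximately reversible transformation → Decode back to an approximation of the original
  (lossy when several columns are collapsed or values fall outside the training range)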

Schema = Constitution (created once, applied forever)
Data = Feed (flows through schema rules)
"""

import json
import hashlib
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler


class QuantumSchema:
    """
    The Constitution: Defines how to collapse columns into quantum features.
    Created once during training, applied forever in production.

    All state is stored as plain JSON-compatible values so the schema can be
    round-tripped through ``to_dict``/``from_dict`` and ``save``/``load``.
    """

    def __init__(self, version: str = "1.0"):
        self.version = version
        self.created = datetime.now().isoformat()
        # feature name -> {"index", "components", "weights", "normalization"}
        self.quantum_features: Dict[str, Dict] = {}
        self.input_columns: List[str] = []
        self.target_column: Optional[str] = None

    def add_quantum_feature(
        self,
        name: str,
        components: List[str],
        weights: List[float],
        normalization: Dict[str, Dict[str, float]],
        index: int
    ):
        """
        Add a quantum feature definition to the schema.

        Args:
            name: Feature name (becomes the output column name).
            components: Source column names collapsed into this feature.
            weights: One weight per component; rescaled here to sum to 1.0.
            normalization: Per-component stats; each entry needs at least
                ``"min"`` and ``"max"`` for quantumization.
            index: Output ordering position of this feature.

        Raises:
            ValueError: if components and weights differ in length, no
                components are given, the weights sum to zero, or a component
                has no normalization entry.
        """
        if len(components) != len(weights):
            raise ValueError(f"Components ({len(components)}) and weights ({len(weights)}) must match")
        if not components:
            raise ValueError(f"Quantum feature '{name}' needs at least one component")

        missing = [c for c in components if c not in normalization]
        if missing:
            raise ValueError(f"Quantum feature '{name}' has no normalization for: {missing}")

        # Normalize weights to sum to 1.0 (a zero sum cannot be normalized).
        weight_sum = sum(weights)
        if weight_sum == 0:
            raise ValueError(f"Weights for quantum feature '{name}' sum to zero; cannot normalize")
        normalized_weights = [w / weight_sum for w in weights]

        self.quantum_features[name] = {
            "index": index,
            # Copy so later mutation of the caller's list cannot alter the schema.
            "components": list(components),
            "weights": normalized_weights,
            "normalization": normalization
        }

    def to_dict(self) -> Dict:
        """Serialize the schema to a dictionary, including an integrity checksum."""
        return {
            "version": self.version,
            "created": self.created,
            "input_columns": self.input_columns,
            "target": self.target_column,
            "quantum_features": self.quantum_features,
            "checksum": self._compute_checksum()
        }

    def _compute_checksum(self) -> str:
        """
        Compute a 16-hex-char SHA-256 prefix over ``quantum_features`` only.

        The other fields (version, created, input columns, target) are not
        covered. Keys are sorted so the result is independent of dict order.
        """
        schema_str = json.dumps(self.quantum_features, sort_keys=True)
        return hashlib.sha256(schema_str.encode()).hexdigest()[:16]

    @classmethod
    def from_dict(cls, data: Dict) -> 'QuantumSchema':
        """
        Deserialize a schema from a dictionary produced by ``to_dict``.

        Raises:
            ValueError: if a stored checksum does not match the recomputed one.
        """
        schema = cls(version=data.get("version", "1.0"))
        schema.created = data.get("created", datetime.now().isoformat())
        schema.input_columns = data.get("input_columns", [])
        schema.target_column = data.get("target")
        schema.quantum_features = data.get("quantum_features", {})

        # Verify checksum (skipped when none was stored).
        stored_checksum = data.get("checksum")
        computed_checksum = schema._compute_checksum()
        if stored_checksum and stored_checksum != computed_checksum:
            raise ValueError(f"Schema checksum mismatch: {stored_checksum} != {computed_checksum}")

        return schema

    def save(self, path: Path):
        """Save the schema as indented JSON, creating parent directories."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: Path) -> 'QuantumSchema':
        """Load a schema from a JSON file (checksum is verified)."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)


class QuantumEngine:
    """
    Core engine for quantum feature operations.
    Handles: analysis, quantumization, decoding.

    Typical flow:
      - ``create_schema_from_data`` derives a QuantumSchema from training data.
      - ``quantumize_dataframe`` collapses columns according to a schema.
      - ``decode_quantum_dataframe`` approximately reverses that collapse.
    """

    def __init__(self, workspace_path: Optional[str] = None):
        """
        Initialise the engine and ensure its schema directory exists.

        Args:
            workspace_path: Workspace root; defaults to the current working
                directory. Schemas are stored under
                ``<workspace>/.chalkml/quantum_schemas`` (created if missing).
        """
        if workspace_path is None:
            workspace_path = Path.cwd()
        self.workspace_path = Path(workspace_path)
        self.chalkml_dir = self.workspace_path / ".chalkml"
        self.schema_dir = self.chalkml_dir / "quantum_schemas"
        self.schema_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _is_numeric(series: pd.Series) -> bool:
        """Return True if *series* holds numeric, non-boolean values."""
        # A single shared rule, so int16/uint8/float16/nullable Int64 columns
        # are not silently dropped by a hard-coded dtype list.
        return (
            pd.api.types.is_numeric_dtype(series)
            and not pd.api.types.is_bool_dtype(series)
        )

    @staticmethod
    def _column_stats(series: pd.Series) -> Dict[str, float]:
        """Return min/max/mean/std of the non-null values of *series* as floats."""
        # Cast first so nullable integer columns give plain floats (NaN for an
        # all-null column) rather than pd.NA, which float() rejects.
        values = series.dropna().astype(float)
        return {
            "min": float(values.min()),
            "max": float(values.max()),
            "mean": float(values.mean()),
            "std": float(values.std())
        }

    def analyze_correlations(self, df: pd.DataFrame, threshold: float = 0.7) -> Dict[str, List[str]]:
        """
        Analyze column correlations to suggest quantum feature groups.

        Columns are grouped greedily: each not-yet-assigned numeric column
        collects every other unassigned column whose absolute Pearson
        correlation with it is >= ``threshold``. Leftover columns (e.g.
        constant ones, whose correlation is NaN) become single-column groups.

        Returns: {group_name: [column_list]} (empty if no numeric columns).
        """
        numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()

        if not numeric_cols:
            # Nothing to group; an empty group would only add noise downstream.
            return {}
        if len(numeric_cols) == 1:
            return {"q_feature_0": numeric_cols}

        corr_matrix = df[numeric_cols].corr().abs()

        groups: Dict[str, List[str]] = {}
        assigned = set()
        group_idx = 0

        for col in numeric_cols:
            if col in assigned:
                continue

            col_corr = corr_matrix[col]
            correlated = [c for c in col_corr[col_corr >= threshold].index if c not in assigned]

            if correlated:
                groups[f"q_group_{group_idx}"] = correlated
                assigned.update(correlated)
                group_idx += 1

        # Assign uncorrelated columns individually
        for col in numeric_cols:
            if col not in assigned:
                groups[f"q_feature_{col}"] = [col]
                assigned.add(col)

        return groups

    def create_schema_from_data(
        self,
        df: pd.DataFrame,
        auto_group: bool = True,
        correlation_threshold: float = 0.7,
        target_column: Optional[str] = None,
        custom_groups: Optional[Dict[str, List[str]]] = None
    ) -> 'QuantumSchema':
        """
        Create quantum schema from training data.

        This is the CONSTITUTION creation phase.

        Args:
            df: Training data.
            auto_group: Group columns by correlation when no custom groups are given.
            correlation_threshold: Threshold passed to ``analyze_correlations``.
            target_column: Column excluded from input columns and preserved as-is.
            custom_groups: Explicit {feature_name: [columns]}; takes precedence.

        Groups keep only their numeric, non-boolean columns; groups left empty
        are skipped. Each feature gets equal weights and min/max/mean/std
        normalization stats computed from non-null values.
        """
        schema = QuantumSchema()

        # Store input columns (excluding target)
        all_cols = df.columns.tolist()
        if target_column and target_column in all_cols:
            schema.target_column = target_column
            all_cols.remove(target_column)
        schema.input_columns = all_cols

        # Determine quantum feature groups
        if custom_groups:
            groups = custom_groups
        elif auto_group:
            groups = self.analyze_correlations(df[all_cols], threshold=correlation_threshold)
        else:
            # Single feature per column
            groups = {f"q_{col}": [col] for col in all_cols if self._is_numeric(df[col])}

        quantum_idx = 0
        for group_name, columns in groups.items():
            numeric_cols = [col for col in columns if col in df.columns and self._is_numeric(df[col])]
            if not numeric_cols:
                continue

            # Equal weights by default
            weights = [1.0 / len(numeric_cols)] * len(numeric_cols)
            # Min-max parameters are frozen here so later transforms are deterministic.
            normalization = {col: self._column_stats(df[col]) for col in numeric_cols}

            schema.add_quantum_feature(
                name=group_name,
                components=numeric_cols,
                weights=weights,
                normalization=normalization,
                index=quantum_idx
            )
            quantum_idx += 1

        return schema

    def quantumize_dataframe(self, df: pd.DataFrame, schema: 'QuantumSchema') -> pd.DataFrame:
        """
        Apply schema to dataframe → produce quantumized dataframe.

        This is the FEED phase. Data flows through the constitution.
        DETERMINISTIC: Same df + Same schema = Same output (always)

        Output columns, in order:
          1. Pass-through identifier-like columns: object/string dtype, a column
             named ``id``, or any column whose name contains ``_id``.
          2. One column per quantum feature, in schema ``index`` order. Each is
             the weighted sum of its components, min-max normalized with the
             schema's stored min/max and clipped to [0, 1].
          3. The schema's target column, if present in ``df``.

        Raises:
            ValueError: if a schema component is missing from ``df`` or cannot
                be converted to float.
        """
        result_data = {}

        # Preserve identifier-like columns untouched (str() tolerates non-string labels).
        for col in df.columns:
            series = df[col]
            is_text = pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)
            if is_text or col == 'id' or '_id' in str(col).lower():
                result_data[col] = series.values

        ordered_features = sorted(
            schema.quantum_features.items(), key=lambda item: item[1]['index']
        )
        for q_name, q_def in ordered_features:
            norm_params = q_def['normalization']
            q_values = np.zeros(len(df), dtype=float)

            # Collapse components into quantum feature
            for component, weight in zip(q_def['components'], q_def['weights']):
                if component not in df.columns:
                    raise ValueError(f"Component '{component}' not found in data. Schema mismatch.")

                # Float ndarray with NaN for missing values (also handles nullable dtypes).
                raw_values = df[component].to_numpy(dtype=float, na_value=np.nan)

                min_val = norm_params[component]['min']
                max_val = norm_params[component]['max']
                span = max_val - min_val

                # A constant training column carries no information: map to 0.
                if span == 0:
                    normalized_values = np.zeros_like(raw_values, dtype=float)
                else:
                    normalized_values = (raw_values - min_val) / span

                # Out-of-training-range values are clipped, which makes decoding lossy.
                q_values += weight * np.clip(normalized_values, 0.0, 1.0)

            result_data[q_name] = q_values

        target = schema.target_column
        if target and target in df.columns:
            result_data[target] = df[target].values

        return pd.DataFrame(result_data)

    def decode_quantum_dataframe(self, q_df: pd.DataFrame, schema: 'QuantumSchema') -> pd.DataFrame:
        """
        Reverse transformation: quantum features → original columns (approximation).

        A quantum value is q = sum(w_i * n_i), where n_i are the normalized
        components. Each component is reconstructed as n_i ≈ q / sum(w),
        which assumes all components in a group share one normalized value
        (the premise of grouping correlated columns), and is then
        de-normalized with the schema's min/max.

        The result is exact for single-component features within the training
        range. It is lossy for multi-component features and for values that
        were clipped during quantumization.

        Columns that are neither quantum features of ``schema`` nor its target
        are passed through unchanged.
        """
        quantum_names = set(schema.quantum_features)
        result_data = {}

        # Preserve non-quantum (e.g. ID) columns
        for col in q_df.columns:
            if col not in quantum_names and col != schema.target_column:
                result_data[col] = q_df[col].values

        reconstructed = {}
        for q_name, q_def in schema.quantum_features.items():
            if q_name not in q_df.columns:
                continue

            weights = q_def['weights']
            total_weight = sum(weights)
            if total_weight == 0:
                continue

            norm_params = q_def['normalization']
            q_values = q_df[q_name].to_numpy(dtype=float, na_value=np.nan)
            # Shared normalized-value estimate for every component of this group.
            normalized_est = q_values / total_weight

            for component, weight in zip(q_def['components'], weights):
                if weight == 0:
                    continue
                min_val = norm_params[component]['min']
                max_val = norm_params[component]['max']
                reconstructed[component] = normalized_est * (max_val - min_val) + min_val

        result_data.update(reconstructed)

        target = schema.target_column
        if target and target in q_df.columns:
            result_data[target] = q_df[target].values

        return pd.DataFrame(result_data)

    def quantumize_file(
        self,
        input_path: str,
        output_path: str,
        schema_path: Optional[str] = None,
        auto_group: bool = True,
        target_column: Optional[str] = None,
        correlation_threshold: float = 0.7
    ) -> Tuple[bool, str, Optional[str]]:
        """
        High-level: Quantumize a CSV file.

        With ``schema_path`` an existing schema is loaded (production mode).
        ``auto_group``, ``target_column`` and ``correlation_threshold`` are then
        ignored. Otherwise a new schema is built from the file and saved under
        ``self.schema_dir`` (training mode).

        Returns: (success, message, schema_path). Any exception is reported
        as ``(False, error_message, None)`` instead of being raised.
        """
        try:
            df = pd.read_csv(input_path)

            if schema_path:
                # Use existing schema (PRODUCTION MODE)
                schema = QuantumSchema.load(Path(schema_path))
                message = f"✅ Loaded schema: {schema_path}\n"
            else:
                # Create new schema (TRAINING MODE)
                schema = self.create_schema_from_data(
                    df,
                    auto_group=auto_group,
                    correlation_threshold=correlation_threshold,
                    target_column=target_column
                )

                # Microseconds keep two runs in the same second from overwriting each other.
                schema_filename = f"quantum_schema_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.json"
                schema_path_obj = self.schema_dir / schema_filename
                schema.save(schema_path_obj)
                schema_path = str(schema_path_obj)

                message = f"✅ Created quantum schema: {schema_path}\n"
                message += f"   Features: {len(schema.quantum_features)}\n"
                for q_name, q_def in schema.quantum_features.items():
                    message += f"   - {q_name}: {len(q_def['components'])} columns collapsed\n"

            q_df = self.quantumize_dataframe(df, schema)
            q_df.to_csv(output_path, index=False)

            orig_cols = len(df.columns)
            quantum_cols = len(q_df.columns)
            compression = (1 - quantum_cols / orig_cols) * 100 if orig_cols > 0 else 0

            message += f"\n✅ Quantumized: {input_path} → {output_path}\n"
            message += f"   Original: {orig_cols} columns\n"
            message += f"   Quantum: {quantum_cols} columns\n"
            message += f"   Compression: {compression:.1f}%\n"
            message += f"   Rows: {len(df)}\n"

            return True, message, schema_path

        except Exception as e:
            return False, f"❌ Quantumization failed: {str(e)}", None

    def analyze_file(self, input_path: str, correlation_threshold: float = 0.7) -> Tuple[bool, str]:
        """
        Analyze a CSV file and suggest quantum feature groups.

        Returns: (success, report). Any exception is reported as
        ``(False, error_message)`` instead of being raised.
        """
        try:
            df = pd.read_csv(input_path)
            groups = self.analyze_correlations(df, threshold=correlation_threshold)

            message = f"📊 Quantum Feature Analysis: {input_path}\n"
            message += f"{'='*60}\n\n"

            message += f"Recommended Groups (correlation ≥ {correlation_threshold}):\n\n"

            for group_name, columns in groups.items():
                message += f"🔹 {group_name}:\n"
                for col in columns:
                    message += f"   - {col}\n"
                message += "\n"

            n_cols = len(df.columns)
            message += f"\nOriginal columns: {n_cols}\n"
            message += f"Quantum features: {len(groups)}\n"
            compression = (1 - len(groups) / n_cols) * 100 if n_cols > 0 else 0
            message += f"Potential compression: {compression:.1f}%\n"

            return True, message

        except Exception as e:
            return False, f"❌ Analysis failed: {str(e)}"


def get_quantum_engine(workspace_path: Optional[str] = None) -> QuantumEngine:
    """
    Factory for a QuantumEngine bound to *workspace_path*.

    A ``None`` workspace means the current working directory (resolved by
    the engine's constructor), which also creates the schema directory.
    """
    engine = QuantumEngine(workspace_path)
    return engine
