
import streamlit as st
import pandas as pd
import numpy as np

def read_csv(file, format_type):
    """Read a CSV file in either standard or Geological Survey layout.

    Parameters
    ----------
    file : str, path-like or file-like
        Anything ``pandas.read_csv`` accepts. File-like objects are rewound
        to the start first, so the same upload can be read more than once.
    format_type : str
        ``"Standard CSV (Headers in row 1)"`` reads the first row as the
        header. Any other value is treated as the Geological Survey layout:
        the first row whose first field is ``'H1000'`` supplies the column
        names, rows whose first field is ``'D'`` are the data, and the leading
        record-type column is dropped from the result.

    Returns
    -------
    pandas.DataFrame or None
        The parsed table, or ``None`` after reporting the error through
        ``st.error`` if reading fails.
    """
    # A stream that was already consumed (e.g. read once to list its columns)
    # would otherwise parse as empty, so rewind it when that is possible.
    seek = getattr(file, "seek", None)
    if callable(seek):
        try:
            seek(0)
        except (OSError, ValueError):
            pass  # non-seekable stream: read from wherever it currently is

    try:
        if format_type == "Standard CSV (Headers in row 1)":
            return pd.read_csv(file, header=0)

        # Geological Survey Format.
        # NOTE(review): pandas sizes the table from the first line it reads;
        # if records before 'H1000' have fewer fields than the data rows the
        # parse may fail with a tokenizing error — confirm against real files.
        raw = pd.read_csv(file, header=None)
        record_type = raw.iloc[:, 0]
        header_rows = raw[record_type == 'H1000']
        if header_rows.empty:
            raise ValueError("no 'H1000' header row found in Geological Survey file")
        headers = header_rows.iloc[0].tolist()
        data_rows = raw[record_type == 'D']
        result_df = pd.DataFrame(data_rows.values, columns=headers)
        # Drop the record-type column (the 'H1000' / 'D' marker).
        return result_df.iloc[:, 1:]
    except Exception as e:
        st.error(f"Error reading file: {str(e)}")
        return None


def process_collar_data(collar_file, format_type, mappings):
    """Load a collar file and normalise it to the standard collar schema.

    ``mappings`` maps the keys ``hole_id``, ``easting``, ``northing``,
    ``elevation``, ``dip`` and ``azimuth`` to column names in the source
    file. Those columns are renamed to HOLE_ID, EASTING, NORTHING,
    ELEVATION, DIP and AZIMUTH, and every one of them except HOLE_ID is
    coerced to numeric (unparseable values become NaN).

    Returns the six standard columns in that order, or ``None`` when the
    file cannot be read, a mapped column is absent, or any other error
    occurs (errors are reported through ``st.error``).
    """
    standard_names = {
        'hole_id': 'HOLE_ID',
        'easting': 'EASTING',
        'northing': 'NORTHING',
        'elevation': 'ELEVATION',
        'dip': 'DIP',
        'azimuth': 'AZIMUTH',
    }
    try:
        table = read_csv(collar_file, format_type)
        if table is None:
            return None

        table = table.rename(
            columns={mappings[key]: std_name for key, std_name in standard_names.items()}
        )
        wanted = list(standard_names.values())

        # Validate the mapping before touching any data.
        absent = [name for name in wanted if name not in table.columns]
        if absent:
            st.error(f"Error: Collar data is missing required columns after mapping: {absent}. Please check your column selections.")
            return None

        # Everything after HOLE_ID is a coordinate or angle.
        for name in wanted[1:]:
            table[name] = pd.to_numeric(table[name], errors='coerce')

        return table[wanted]
    except Exception as e:
        st.error(f"Error processing collar file: {str(e)}")
        return None
    
def process_assay_data(assay_file, format_type, mappings, element_cols):
    """Load an assay file, normalise its interval columns and clean element grades.

    Parameters
    ----------
    assay_file, format_type
        Passed straight to ``read_csv``.
    mappings : dict
        Maps the keys ``hole_id``, ``from`` and ``to`` to source column names.
        Those columns are renamed to HOLE_ID, FROM and TO.
    element_cols : list of str or None
        Element columns to keep. Names missing from the file, repeated names
        and names equal to HOLE_ID/FROM/TO are skipped.

    Returns
    -------
    tuple
        ``(assay_df, found_elements)`` on success. ``assay_df`` holds HOLE_ID,
        FROM, TO and the kept element columns, and ``found_elements`` lists
        those element columns. On any failure, including an unreadable file,
        the result is ``(None, None)``, so callers can always unpack two
        values. Errors are reported through ``st.error``.
    """
    base_cols = ['HOLE_ID', 'FROM', 'TO']
    try:
        assay_df = read_csv(assay_file, format_type)
        if assay_df is None:
            # read_csv already reported the problem; keep the 2-tuple contract.
            return None, None

        assay_df = assay_df.rename(columns={
            mappings['hole_id']: 'HOLE_ID',
            mappings['from']: 'FROM',
            mappings['to']: 'TO',
        })

        # Ensure the essential columns exist before proceeding.
        missing = [col for col in base_cols if col not in assay_df.columns]
        if missing:
            st.error(f"Error: Assay data is missing required columns after mapping: {missing}")
            return None, None

        # Interval depths are plain numbers; detection-limit handling is only
        # meaningful for grades, so depths are just coerced to numeric.
        for col in ('FROM', 'TO'):
            assay_df[col] = pd.to_numeric(assay_df[col], errors='coerce')

        requested = [] if element_cols is None else element_cols
        # Keep each requested element once, only if present, and never a base
        # column (that would duplicate a column in the output).
        found_elements = [
            col for col in dict.fromkeys(requested)
            if col in assay_df.columns and col not in base_cols
        ]
        for col in found_elements:
            # '<0.01' becomes -0.01, so below-detection values end up negative.
            grades = pd.to_numeric(
                assay_df[col].astype(str).str.replace('<', '-'), errors='coerce'
            )
            # Any negative grade is replaced by half its magnitude (half the
            # detection limit). NOTE(review): this also applies to values that
            # were already negative in the source (e.g. numeric sentinel codes,
            # if the lab uses them) — confirm that is intended.
            assay_df[col] = grades.mask(grades < 0, grades.abs() / 2)

        assay_df = assay_df[base_cols + found_elements]
        return assay_df, found_elements
    except Exception as e:
        st.error(f"Error processing assay file: {str(e)}")
        return None, None

def process_litho_data(litho_file, format_type, mappings):
    """Load a lithology interval file and normalise it to the standard schema.

    ``mappings`` maps the keys ``hole_id``, ``from``, ``to`` and ``litho``
    to source column names. Those columns are renamed to HOLE_ID, FROM, TO
    and LITHO, and FROM/TO are coerced to numeric (unparseable values become
    NaN).

    Returns a DataFrame with exactly HOLE_ID, FROM, TO, LITHO, or ``None``
    when the file cannot be read, a mapped column is absent, or any other
    error occurs (errors are reported through ``st.error``).
    """
    required_cols = ['HOLE_ID', 'FROM', 'TO', 'LITHO']
    try:
        litho_df = read_csv(litho_file, format_type)
        if litho_df is None:
            return None

        litho_df = litho_df.rename(columns={
            mappings['hole_id']: 'HOLE_ID',
            mappings['from']: 'FROM',
            mappings['to']: 'TO',
            mappings['litho']: 'LITHO',
        })

        # Same explicit validation as the collar/assay loaders, instead of an
        # opaque KeyError from the column selection below.
        missing = [col for col in required_cols if col not in litho_df.columns]
        if missing:
            st.error(f"Error: Lithology data is missing required columns after mapping: {missing}. Please check your column selections.")
            return None

        for col in ('FROM', 'TO'):
            litho_df[col] = pd.to_numeric(litho_df[col], errors='coerce')

        return litho_df[required_cols]
    except Exception as e:
        st.error(f"Error processing lithology file: {str(e)}")
        return None

def process_litho_dict(litho_dict_file, format_type, mappings):
    """Build a ``{code: description}`` lookup from a lithology dictionary file.

    ``mappings['code']`` and ``mappings['desc']`` name the source columns
    that hold the codes and their descriptions. When a code appears on more
    than one row, the description from its last row wins.

    Returns the dict, or ``None`` when the file cannot be read or a mapped
    column is absent (errors are reported through ``st.error``).
    """
    try:
        table = read_csv(litho_dict_file, format_type)
        if table is None:
            return None

        code_column, desc_column = mappings['code'], mappings['desc']
        lookup = {}
        for code, description in zip(table[code_column], table[desc_column]):
            lookup[code] = description
        return lookup
    except Exception as e:
        st.error(f"Error processing lithology dictionary file: {str(e)}")
        return None

# --- COMPOSITE AND MERGE FUNCTIONS ---
def composite_geochemical_data(df, element_cols, composite_length):
    """Create length-weighted composites of assay intervals at a fixed length.

    For each HOLE_ID, consecutive windows of ``composite_length`` are laid
    from the hole's smallest FROM to its largest TO, and the last window is
    truncated at that depth. A window's value for each element is the
    average of the overlapping intervals, weighted by how much of each
    interval falls inside the window. Intervals whose value is NaN, or whose
    overlap length is not positive, are ignored. If nothing usable remains,
    the value is NaN. Windows that no interval overlaps are skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain HOLE_ID, FROM, TO and every column in ``element_cols``.
    element_cols : iterable of str
        Numeric columns to composite.
    composite_length : float
        Window length in the same units as FROM/TO. Must be positive.

    Returns
    -------
    pandas.DataFrame
        Columns HOLE_ID, FROM, TO followed by ``element_cols``. The columns are
        present even when no composite was produced. If ``df`` lacks the
        interval columns, or ``composite_length`` is not positive, the error
        is reported through ``st.error`` and ``df`` is returned unchanged.
    """
    if 'HOLE_ID' not in df.columns or 'FROM' not in df.columns or 'TO' not in df.columns:
        st.error("DataFrame must have columns 'HOLE_ID', 'FROM', 'TO' for compositing.")
        return df
    # A non-positive length would never advance the window (infinite loop);
    # `not x > 0` also rejects NaN.
    if not composite_length > 0:
        st.error("Composite length must be a positive number.")
        return df

    # Materialise once so a generator is not exhausted by the first composite.
    element_cols = list(element_cols)
    composited_rows = []
    for hole_id, hole_data in df.groupby('HOLE_ID', sort=False):
        hole_data = hole_data.sort_values('FROM')
        hole_start = hole_data['FROM'].min()
        hole_end = hole_data['TO'].max()

        window_index = 0
        composite_top = hole_start
        while composite_top < hole_end:
            window_index += 1
            # Derive each boundary from the hole start rather than summing
            # lengths, so float error cannot accumulate into a sliver window.
            # Boundaries within float noise of the hole end snap onto it.
            composite_bot = hole_start + window_index * composite_length
            if composite_bot >= hole_end - 1e-9 * composite_length:
                composite_bot = hole_end

            overlap = hole_data[
                (hole_data['FROM'] < composite_bot) &
                (hole_data['TO'] > composite_top)
            ]
            if not overlap.empty:
                starts = overlap['FROM'].clip(lower=composite_top)
                ends = overlap['TO'].clip(upper=composite_bot)
                lengths = (ends - starts).to_numpy(dtype=float)

                composited_row = {
                    'HOLE_ID': hole_id,
                    'FROM': composite_top,
                    'TO': composite_bot,
                }
                for elem in element_cols:
                    values = overlap[elem].to_numpy(dtype=float, na_value=np.nan)
                    # Skip missing grades and degenerate/inverted intervals so
                    # one gap does not blank the composite and weights never
                    # sum to zero.
                    usable = ~np.isnan(values) & (lengths > 0)
                    composited_row[elem] = (
                        np.average(values[usable], weights=lengths[usable])
                        if usable.any()
                        else np.nan
                    )
                composited_rows.append(composited_row)

            composite_top = composite_bot

    # Explicit columns keep the schema (and the later merge on HOLE_ID)
    # valid even when no composite was produced.
    return pd.DataFrame(composited_rows, columns=['HOLE_ID', 'FROM', 'TO'] + element_cols)

def _add_straight_hole_coords(df):
    """Add MIDPOINT, AZIMUTH_RAD, DIP_RAD, dx, dy, dz, x, y, z to ``df`` in place.

    Each hole is treated as a straight line from its collar
    (EASTING, NORTHING, ELEVATION) along its AZIMUTH/DIP in degrees. Each
    interval is placed at its FROM/TO midpoint. Returns ``df`` for
    convenience.
    """
    df['MIDPOINT'] = (df['FROM'] + df['TO']) / 2
    # Compass bearing (clockwise from north) -> math angle (counter-clockwise
    # from east), so that cos gives the east offset and sin the north offset.
    df['AZIMUTH_RAD'] = np.radians(90 - df['AZIMUTH'])
    df['DIP_RAD'] = np.radians(df['DIP'])
    df['dx'] = df['MIDPOINT'] * np.cos(df['DIP_RAD']) * np.cos(df['AZIMUTH_RAD'])
    df['dy'] = df['MIDPOINT'] * np.cos(df['DIP_RAD']) * np.sin(df['AZIMUTH_RAD'])
    # z = ELEVATION + MIDPOINT*sin(DIP): holes drilled downward need negative
    # DIP values. NOTE(review): confirm the input files use that convention.
    df['dz'] = df['MIDPOINT'] * np.sin(df['DIP_RAD'])
    df['x'] = df['EASTING'] + df['dx']
    df['y'] = df['NORTHING'] + df['dy']
    df['z'] = df['ELEVATION'] + df['dz']
    return df


def _best_litho(row, litho_by_hole):
    """Return the LITHO that best matches one interval row, or None.

    Among the hole's lithology intervals that overlap the row, the one with
    the longest overlap wins (first one on ties). If none overlap, the
    interval whose FROM or TO is closest to the row's MIDPOINT wins. None is
    returned when the hole has no lithology, or when the row's or the
    lithology's depths are missing.
    """
    hole_lithos = litho_by_hole.get(row['HOLE_ID'])
    if hole_lithos is None:
        return None
    row_from, row_to = row['FROM'], row['TO']
    if pd.isna(row_from) or pd.isna(row_to):
        # No depth to compare against; idxmin/idxmax over all-NaN would fail.
        return None

    overlaps = hole_lithos[
        ((hole_lithos['FROM'] <= row_from) & (hole_lithos['TO'] > row_from)) |
        ((hole_lithos['FROM'] < row_to) & (hole_lithos['TO'] >= row_to)) |
        ((hole_lithos['FROM'] >= row_from) & (hole_lithos['TO'] <= row_to))
    ]
    if not overlaps.empty:
        overlap_len = np.minimum(overlaps['TO'], row_to) - np.maximum(overlaps['FROM'], row_from)
        return overlaps.loc[overlap_len.idxmax(), 'LITHO']

    midpoint = row['MIDPOINT']
    distance = np.minimum(
        np.abs(hole_lithos['FROM'] - midpoint),
        np.abs(hole_lithos['TO'] - midpoint),
    ).dropna()
    if distance.empty:
        return None
    return hole_lithos.loc[distance.idxmin(), 'LITHO']


def process_and_merge_data(collar_df, assay_df, litho_df, element_cols, composite_enabled, composite_length):
    """Merge collar, assay and lithology data and place intervals in 3D.

    Parameters
    ----------
    collar_df : pandas.DataFrame or None
        Needs HOLE_ID, EASTING, NORTHING, ELEVATION, DIP, AZIMUTH.
    assay_df : pandas.DataFrame or None
        Needs HOLE_ID, FROM, TO and the element columns.
    litho_df : pandas.DataFrame or None
        Needs HOLE_ID, FROM, TO, LITHO.
    element_cols : list of str
        Element columns. If ``composite_enabled`` is true and this list is
        non-empty, the assays are first composited with
        ``composite_geochemical_data(assay_df, element_cols, composite_length)``.
    composite_enabled : bool
        Whether to composite the assays before merging.
    composite_length : float
        Composite length, passed to ``composite_geochemical_data``.

    Returns
    -------
    tuple
        ``(merged_df, viz_litho_df)``.

        - ``merged_df`` is the inner join of the collars with the assays,
          plus straight-hole coordinates. When lithology is also given, it
          gets a LITHO column chosen per interval by longest overlap, or else
          the nearest lithology interval. Without assays, it is the collars
          joined with the lithology.
        - ``viz_litho_df`` is the lithology intervals joined to their collars,
          with coordinates.

        Each element is None when its inputs are missing. Both are None when
        ``collar_df`` is None.
    """
    merged_df = None
    viz_litho_df = None
    if collar_df is None:
        return merged_df, viz_litho_df

    if assay_df is not None:
        if composite_enabled and element_cols:
            assay_df = composite_geochemical_data(assay_df, element_cols, composite_length)

        merged_df = pd.merge(collar_df, assay_df, on='HOLE_ID', how='inner')
        if 'FROM' in merged_df.columns and 'TO' in merged_df.columns:
            _add_straight_hole_coords(merged_df)

    if litho_df is not None:
        collar_cols = collar_df[['HOLE_ID', 'EASTING', 'NORTHING', 'ELEVATION', 'DIP', 'AZIMUTH']]
        viz_litho_df = _add_straight_hole_coords(pd.merge(litho_df, collar_cols, on='HOLE_ID'))

        if merged_df is None:
            merged_df = _add_straight_hole_coords(
                pd.merge(collar_df, litho_df, on='HOLE_ID', how='inner')
            )
        else:
            # Group the lithology once per hole instead of rescanning the
            # whole table for every merged row.
            litho_by_hole = {
                hole_id: group for hole_id, group in litho_df.groupby('HOLE_ID', sort=False)
            }
            merged_df['LITHO'] = merged_df.apply(
                lambda row: _best_litho(row, litho_by_hole), axis=1
            )

    return merged_df, viz_litho_df
