use std::sync::Arc;

use datafusion::arrow::compute::can_cast_types;
use datafusion::arrow::datatypes::{Field, FieldRef, Schema as ArrowSchema, SchemaRef};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::physical_expr::expressions::{Column, Literal};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr_adapter::{PhysicalExprAdapter, PhysicalExprAdapterFactory};

/// Factory that builds one [`PhysicalExprAdapter`] per data file, adapting expressions
/// written against the logical file schema to that file's physical schema.
///
/// The factory carries no state, so `Clone` and `Default` are derived: callers can write
/// `IcebergPhysicalExprAdapterFactory::default()` and duplicate it freely.
#[derive(Debug, Clone, Default)]
pub struct IcebergPhysicalExprAdapterFactory {}

impl PhysicalExprAdapterFactory for IcebergPhysicalExprAdapterFactory {
    /// Builds an adapter for one data file.
    ///
    /// The logical→physical column mapping and the fill values for missing columns are
    /// computed once here, so every later `rewrite` call only has to consult them. The
    /// returned adapter starts with no partition values; use `with_partition_values` to
    /// attach them.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        let (mapping, defaults) =
            create_column_mapping(&logical_file_schema, &physical_file_schema);

        let adapter = IcebergPhysicalExprAdapter {
            column_mapping: mapping,
            default_values: defaults,
            partition_values: Vec::new(),
            logical_file_schema,
            physical_file_schema,
        };
        Arc::new(adapter)
    }
}

/// Builds the logical → physical column mapping for one data file.
///
/// For every field of `logical_schema`, in order, this returns one entry in each vector:
/// - `column_mapping[i]` is `Some(j)` when `physical_schema` has a field with the same name
///   at index `j`, and `None` when the file lacks the column.
/// - `default_values[i]` is `None` when the column exists in the file. Otherwise it holds
///   the fill value for the missing column: a typed null for nullable fields, or the type's
///   zero value for non-nullable fields.
///
/// If no correctly-typed fill value can be built for a missing column, its default is left
/// as `None`. The rewriter then falls back to a typed-null cast, or reports an error for
/// non-nullable columns. This avoids injecting an untyped `ScalarValue::Null` literal whose
/// data type does not match the column.
///
/// NOTE(review): columns are matched by name rather than by Iceberg field id, so a renamed
/// column is treated as missing. Zero-filling a missing non-nullable column also fabricates
/// data instead of failing. Confirm both are the intended semantics.
fn create_column_mapping(
    logical_schema: &ArrowSchema,
    physical_schema: &ArrowSchema,
) -> (Vec<Option<usize>>, Vec<Option<ScalarValue>>) {
    logical_schema
        .fields()
        .iter()
        .map(|logical_field| match physical_schema.index_of(logical_field.name()) {
            // Column is present in the file: map it and no fill value is needed.
            Ok(physical_index) => (Some(physical_index), None),
            Err(_) => {
                let data_type = logical_field.data_type();
                // `.ok()` rather than falling back to an untyped `ScalarValue::Null`,
                // which would carry `DataType::Null` instead of the column's type.
                let default_value = if logical_field.is_nullable() {
                    ScalarValue::try_from(data_type).ok()
                } else {
                    ScalarValue::new_zero(data_type).ok()
                };
                (None, default_value)
            }
        })
        .unzip()
}

/// Per-file expression adapter produced by `IcebergPhysicalExprAdapterFactory::create`.
///
/// It holds the two schemas plus the precomputed column mapping and fill values that the
/// rewriter consults for each column reference.
#[derive(Debug)]
struct IcebergPhysicalExprAdapter {
    // Logical file schema: column names are resolved against it first.
    logical_file_schema: SchemaRef,
    // Physical schema of the data file being read.
    physical_file_schema: SchemaRef,
    // Columns whose references are replaced by these constant literals (matched by name).
    partition_values: Vec<(FieldRef, ScalarValue)>,
    // Indexed by logical field position: physical field index, or `None` if the file lacks it.
    column_mapping: Vec<Option<usize>>,
    // Indexed by logical field position: fill value for a column missing from the file.
    default_values: Vec<Option<ScalarValue>>,
}

impl PhysicalExprAdapter for IcebergPhysicalExprAdapter {
    /// Rewrites `expr` so its column references resolve against the physical file schema.
    ///
    /// Every node of the tree is passed to the rewriter via `TreeNode::transform`. Column
    /// nodes may be re-indexed, cast, or replaced with literals; all other nodes are kept.
    ///
    /// # Errors
    /// Propagates the first error raised while rewriting any node.
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        let rewriter = IcebergPhysicalExprRewriter {
            default_values: &self.default_values,
            column_mapping: &self.column_mapping,
            partition_values: &self.partition_values,
            physical_file_schema: &self.physical_file_schema,
            logical_file_schema: &self.logical_file_schema,
        };
        // The callback already owns each node, so it can be handed over directly.
        expr.transform(|node| rewriter.rewrite_expr(node)).data()
    }

    /// Returns a new adapter that replaces references to the given partition columns with
    /// their constant values.
    ///
    /// The schemas are shared (`Arc` refcount bump). The mapping and defaults are copied,
    /// and any previously attached partition values are discarded.
    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter> {
        let adapter = Self {
            partition_values,
            logical_file_schema: Arc::clone(&self.logical_file_schema),
            physical_file_schema: Arc::clone(&self.physical_file_schema),
            column_mapping: self.column_mapping.clone(),
            default_values: self.default_values.clone(),
        };
        Arc::new(adapter)
    }
}

impl Clone for IcebergPhysicalExprAdapter {
    /// Field-by-field copy.
    ///
    /// Both schemas are shared through `Arc::clone`. The partition values, column mapping
    /// and default values are cloned as independent vectors.
    fn clone(&self) -> Self {
        let Self {
            logical_file_schema,
            physical_file_schema,
            partition_values,
            column_mapping,
            default_values,
        } = self;

        Self {
            logical_file_schema: Arc::clone(logical_file_schema),
            physical_file_schema: Arc::clone(physical_file_schema),
            partition_values: partition_values.clone(),
            column_mapping: column_mapping.clone(),
            default_values: default_values.clone(),
        }
    }
}

/// Borrowed view of an adapter's state, used while rewriting one expression tree.
///
/// All fields are borrows, so building one is cheap and copies no data.
struct IcebergPhysicalExprRewriter<'s> {
    /// Logical file schema; column names are looked up here first.
    logical_file_schema: &'s ArrowSchema,
    /// Physical schema of the data file being read.
    physical_file_schema: &'s ArrowSchema,
    /// Columns replaced by constant literals, matched by field name.
    partition_values: &'s [(FieldRef, ScalarValue)],
    /// Logical field index -> physical field index (`None` when the file lacks the column).
    column_mapping: &'s [Option<usize>],
    /// Logical field index -> fill value for a column missing from the file.
    default_values: &'s [Option<ScalarValue>],
}

impl IcebergPhysicalExprRewriter<'_> {
    /// Per-node callback for `TreeNode::transform`.
    ///
    /// Rewrites `Column` nodes and returns every other node unchanged.
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            // `column` borrows from `expr`, so the rewriter gets its own handle to the node.
            return self.rewrite_column(Arc::clone(&expr), column);
        }
        Ok(Transformed::no(expr))
    }

    /// Rewrites a single column reference.
    ///
    /// Resolution order:
    /// 1. If a partition value has the column's name, the column becomes that literal.
    /// 2. If the column is absent from the logical schema but present in the physical
    ///    schema, it is left untouched.
    /// 3. If the logical column exists in the file, it is re-indexed and/or cast
    ///    (see [`Self::handle_existing_column`]).
    /// 4. If the logical column is missing from the file, it is replaced by its precomputed
    ///    default, or else by a typed null when the field is nullable.
    ///
    /// # Errors
    /// Returns an execution error in any of these cases:
    /// - the column is in neither schema;
    /// - a missing non-nullable column has no default;
    /// - a typed null cannot be built;
    /// - the column mapping has no entry for the logical index;
    /// - a required cast is unsupported.
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(partition_value) = self.get_partition_value(column.name()) {
            return Ok(Transformed::yes(Arc::new(Literal::new(partition_value))));
        }

        let logical_field_index = match self.logical_file_schema.index_of(column.name()) {
            Ok(index) => index,
            // Physical-only column: there is no logical counterpart to adapt to.
            Err(_) if self.physical_file_schema.field_with_name(column.name()).is_ok() => {
                return Ok(Transformed::no(expr));
            }
            Err(_) => {
                return exec_err!(
                    "Column '{}' not found in either logical or physical schema",
                    column.name()
                );
            }
        };

        let logical_field = self.logical_file_schema.field(logical_field_index);

        match self.column_mapping.get(logical_field_index) {
            Some(Some(physical_index)) => {
                let physical_field = self.physical_file_schema.field(*physical_index);
                self.handle_existing_column(
                    expr,
                    column,
                    logical_field,
                    physical_field,
                    *physical_index,
                )
            }
            Some(None) => {
                if let Some(Some(default_value)) = self.default_values.get(logical_field_index) {
                    Ok(Transformed::yes(Arc::new(Literal::new(
                        default_value.clone(),
                    ))))
                } else if logical_field.is_nullable() {
                    let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
                    Ok(Transformed::yes(Arc::new(Literal::new(null_value))))
                } else {
                    exec_err!(
                        "Non-nullable column '{}' is missing from physical schema and no default value provided",
                        column.name()
                    )
                }
            }
            None => exec_err!(
                "Column mapping not found for logical field index {}",
                logical_field_index
            ),
        }
    }

    /// Adapts a column that exists in both schemas.
    ///
    /// - If the expression's index differs from `physical_index`, the column is rebuilt
    ///   pointing at `physical_index`. The index is already known from the column mapping,
    ///   so no fallible name lookup in the physical schema is needed.
    /// - If the physical data type differs from the logical one, the column is wrapped in a
    ///   cast to the logical type.
    ///
    /// When neither applies, the original expression is returned unchanged.
    fn handle_existing_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
        logical_field: &Field,
        physical_field: &Field,
        physical_index: usize,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        let needs_index_update = column.index() != physical_index;
        let needs_type_cast = logical_field.data_type() != physical_field.data_type();

        if !needs_index_update && !needs_type_cast {
            return Ok(Transformed::no(expr));
        }

        let column_expr: Arc<dyn PhysicalExpr> = if needs_index_update {
            Arc::new(Column::new(logical_field.name(), physical_index))
        } else {
            expr
        };

        if needs_type_cast {
            self.apply_type_cast(column_expr, logical_field, physical_field)
        } else {
            Ok(Transformed::yes(column_expr))
        }
    }

    /// Wraps `column_expr` in a cast from the physical type to the logical type.
    ///
    /// The cast uses default cast options (`None`).
    ///
    /// # Errors
    /// Returns an execution error when Arrow cannot cast between the two types.
    fn apply_type_cast(
        &self,
        column_expr: Arc<dyn PhysicalExpr>,
        logical_field: &Field,
        physical_field: &Field,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
            return exec_err!(
                "Cannot cast column '{}' from '{}' (physical) to '{}' (logical)",
                logical_field.name(),
                physical_field.data_type(),
                logical_field.data_type()
            );
        }
        let cast_expr = datafusion::physical_expr::expressions::CastExpr::new(
            column_expr,
            logical_field.data_type().clone(),
            None,
        );
        Ok(Transformed::yes(Arc::new(cast_expr)))
    }

    /// Returns the value of the first partition entry whose field name equals
    /// `column_name`, if any.
    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
        self.partition_values
            .iter()
            .find(|(field, _)| field.name() == column_name)
            .map(|(_, value)| value.clone())
    }
}
