import numpy as np
from typing import List, Tuple, Callable, Optional


def _velocity_column_indices(component_names: List[str]) -> List[int]:
    """Return the column indices of ``velocity_x``, ``velocity_y``, ``velocity_z``.

    Args:
        component_names: Component (column) names of the data array.

    Returns:
        The indices of the three velocity components, in x, y, z order.

    Raises:
        ValueError: If any velocity component is missing; the message lists
            every missing component and the available ones.
    """
    vel_comps = ("velocity_x", "velocity_y", "velocity_z")
    missing = [comp for comp in vel_comps if comp not in component_names]
    if missing:
        raise ValueError(
            f"Data is missing required velocity components {missing}; "
            f"available components: {list(component_names)}"
        )
    return [component_names.index(comp) for comp in vel_comps]


def _perpendicular_basis(
    b_hat: np.ndarray, e_field: np.ndarray, rtol: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray]:
    """Complete ``b_hat`` into a right-handed orthonormal basis using ``e_field``.

    Args:
        b_hat: Unit vector along the magnetic field.
        e_field: Electric field vector (3 components).
        rtol: Relative tolerance. E is treated as parallel to B when its
            perpendicular part is at most ``rtol * |E|``.

    Returns:
        ``(e_perp_hat, b_cross_e_hat)``: the unit vector along the part of E
        perpendicular to B, and ``b_hat x e_perp_hat``.

    Raises:
        ValueError: If E is (numerically) parallel to B, or is zero.
    """
    e_perp = e_field - np.dot(e_field, b_hat) * b_hat
    norm_e_perp = np.linalg.norm(e_perp)
    # The tolerance is relative to |E| so the check is scale-invariant: fields
    # given in physical units with tiny magnitudes must not be rejected merely
    # for being small. A zero E field also lands here (0 <= 0).
    if norm_e_perp <= rtol * np.linalg.norm(e_field):
        raise ValueError(
            "E field is parallel to B field, cannot form an orthonormal basis."
        )
    e_perp_hat = e_perp / norm_e_perp
    return e_perp_hat, np.cross(b_hat, e_perp_hat)


def create_field_transform(
    b_field: np.ndarray,
    original_component_names: List[str],
    e_field: Optional[np.ndarray] = None,
) -> Callable[[np.ndarray], Tuple[np.ndarray, List[str]]]:
    """
    Creates a transformation function to align velocity coordinates with E and B fields.

    This function returns a new transformation function that can be passed to plotter methods.
    The returned function takes a 2-D data array (rows = samples, columns named by
    ``original_component_names``) and calculates new velocity components.

    - If only `b_field` is provided, it calculates velocity components parallel and
      perpendicular to the magnetic field (`v_parallel`, `v_perp`).
    - If both `b_field` and `e_field` are provided, it builds the right-handed
      orthonormal basis (B_hat, E_perp_hat, B_hat x E_perp_hat) and projects the
      velocity onto it (`v_B`, `v_E`, `v_BxE`).

    The function returns a new data array containing the original data *without* the
    old velocity components, but with the new field-aligned components appended.

    Args:
        b_field (np.ndarray): The magnetic field vector (3 components, array-like accepted).
        original_component_names (List[str]): The list of component names for the input data array.
        e_field (np.ndarray, optional): The electric field vector (3 components, array-like accepted).

    Returns:
        A transformation function that takes a NumPy array and returns a tuple containing
        the transformed data array and the new list of component names.

    Raises:
        ValueError: Immediately, if a field does not have shape (3,) or if
            `b_field` is zero or non-finite. When the returned function is
            called, if a velocity component is missing from
            `original_component_names`, or if `e_field` is parallel to `b_field`.
    """
    b_field = np.asarray(b_field)
    if b_field.shape != (3,):
        raise ValueError(
            f"b_field must be a 3-component vector, but got shape {b_field.shape}"
        )
    if e_field is not None:
        e_field = np.asarray(e_field)
        if e_field.shape != (3,):
            raise ValueError(
                f"e_field must be a 3-component vector, but got shape {e_field.shape}"
            )

    # Normalizing a zero or non-finite vector would silently fill every output
    # column with NaN, so reject it up front.
    b_norm = np.linalg.norm(b_field)
    if not np.isfinite(b_norm) or b_norm == 0:
        raise ValueError(f"b_field must be a finite, non-zero vector, but got {b_field}")
    # Depends only on the field, so compute it once rather than per call.
    b_hat = b_field / b_norm

    def field_transform(data: np.ndarray) -> Tuple[np.ndarray, List[str]]:
        """The actual transformation function generated by the factory."""
        indices = _velocity_column_indices(original_component_names)
        v = data[:, indices]

        # Component along B is common to both transformations.
        v_parallel = v @ b_hat

        if e_field is None:
            # Take the norm of the explicit perpendicular vector instead of
            # sqrt(|v|^2 - v_par^2): the subtraction cancels catastrophically
            # (about half the significant digits) when v is nearly parallel to B.
            v_perp = np.linalg.norm(v - np.outer(v_parallel, b_hat), axis=1)
            new_components = np.c_[v_parallel, v_perp]
            new_names = ["v_parallel", "v_perp"]
        else:
            e_perp_hat, b_cross_e_hat = _perpendicular_basis(b_hat, e_field)
            new_components = np.c_[v_parallel, v @ e_perp_hat, v @ b_cross_e_hat]
            new_names = ["v_B", "v_E", "v_BxE"]

        # Keep every column except the original velocity components.
        dropped = set(indices)
        n_cols = data.shape[1]
        keep_indices = [i for i in range(n_cols) if i not in dropped]
        transformed_data = np.c_[data[:, keep_indices], new_components]

        # Names follow the kept columns, then the new field-aligned components.
        kept_names = [
            name
            for i, name in enumerate(original_component_names)
            if i < n_cols and i not in dropped
        ]
        return transformed_data, kept_names + new_names

    return field_transform
